import html
import time

import streamlit as st
import torch

from model_handler import load_model_and_tokenizer, generate_response
# Page configuration: browser-tab title and icon, plus a centered layout.
_PAGE_SETTINGS = {
    "page_title": "Phi-2 Fine-tuned Assistant",
    "page_icon": "🤖",
    "layout": "centered",
}
st.set_page_config(**_PAGE_SETTINGS)
# Custom CSS for chat interface
# NOTE(review): the style sheet body is empty here (only a newline), so this
# call currently injects no rules - confirm whether the CSS was lost.
_CHAT_CSS = "\n"
st.markdown(_CHAT_CSS, unsafe_allow_html=True)
# Initialize session state: seed each key once so that values written by an
# earlier run of the script are not overwritten on later reruns.
_STATE_DEFAULTS = (
    ("messages", []),         # chat history as {"role", "content"} dicts
    ("model_loaded", False),  # flips to True after the model has been loaded
)
for _state_key, _initial_value in _STATE_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# App title - rendered as a small, centered heading to reduce vertical space.
# The literal is kept on one line: a plain quoted string cannot span lines.
st.markdown(
    '<h3 style="text-align: center; margin: 0;">Phi-2 Fine-tuned Assistant</h3>',
    unsafe_allow_html=True,
)
# Load model (only once): the "model_loaded" flag lives in session state, so
# reruns of the script within the same session skip this branch.
if not st.session_state["model_loaded"]:
    with st.spinner("Loading the fine-tuned model... This may take a minute."):
        model, tokenizer = load_model_and_tokenizer()
        # Keep both objects in session state so later reruns can reuse them.
        st.session_state["model"] = model
        st.session_state["tokenizer"] = tokenizer
        st.session_state["model_loaded"] = True
    st.success("Model loaded successfully!")
# --- Chat UI: history, input row, and streamed assistant reply ---------------


def _bubble(role, text):
    """Return the HTML snippet for one chat message.

    Args:
        role: "user" for the person typing; any other value is rendered as an
            assistant message.
        text: Raw message text. It is untrusted (typed by the user or produced
            by the model) and is rendered with ``unsafe_allow_html=True``, so
            it is HTML-escaped here.

    Returns:
        A single-line ``<div>`` whose CSS classes (``chat-message`` plus
        ``user-message`` / ``assistant-message``) are hooks for page styling.
    """
    kind = "user" if role == "user" else "assistant"
    # Newlines become <br> so multi-line text keeps its line breaks inside
    # the HTML element.
    safe_text = html.escape(text).replace("\n", "<br>")
    return f'<div class="chat-message {kind}-message">{safe_text}</div>'


def _handle_submit():
    """Widget callback: queue the typed message and empty the input box."""
    text = st.session_state.user_input
    # A widget's value may only be reset from a callback; assigning to its key
    # after the widget has been created in the script body raises an error.
    st.session_state.user_input = ""
    if text.strip():
        # Record the user's message now so it is already part of the history
        # when the script body runs and the reply starts streaming.
        st.session_state.messages.append({"role": "user", "content": text})
        st.session_state.pending_prompt = text


def _handle_clear():
    """Widget callback: drop the chat history and any queued prompt."""
    st.session_state.messages = []
    st.session_state.pending_prompt = None


# Display chat messages
for message in st.session_state.messages:
    st.markdown(
        _bubble(message["role"], message["content"]),
        unsafe_allow_html=True,
    )

# Reserve the slot for the assistant's reply here so the streamed text appears
# directly under the history instead of below the input row.
assistant_response_placeholder = st.empty()

# Chat input
with st.container():
    col1, col2 = st.columns([5, 1])
    with col1:
        st.text_input(
            "Message",  # non-empty label for accessibility; hidden below
            key="user_input",
            placeholder="Type your message here...",
            label_visibility="collapsed",
            on_change=_handle_submit,
        )
    with col2:
        # Callbacks run before the script body on the next run, so no explicit
        # rerun (the deprecated st.experimental_rerun) is needed.
        st.button("Clear", on_click=_handle_clear)

# Process the message queued by the submit callback, if any
prompt = st.session_state.get("pending_prompt")
if prompt:
    # Un-queue first so an error or an interrupted run does not replay it.
    st.session_state.pending_prompt = None

    # Display "Assistant is typing..." until the first token arrives
    assistant_response_placeholder.markdown(
        '<div class="chat-message assistant-message">'
        "<em>Assistant is typing...</em></div>",
        unsafe_allow_html=True,
    )

    # Generate response with streaming; only the current prompt is passed
    full_response = ""
    for token in generate_response(
        st.session_state.model,
        st.session_state.tokenizer,
        prompt,
    ):
        full_response += token
        # Update the response in real-time
        assistant_response_placeholder.markdown(
            _bubble("assistant", full_response),
            unsafe_allow_html=True,
        )
        time.sleep(0.01)  # Small delay to make streaming visible

    # Replace the typing indicator even if the generator produced no tokens.
    assistant_response_placeholder.markdown(
        _bubble("assistant", full_response),
        unsafe_allow_html=True,
    )

    # Add assistant's response to chat history; the next run renders it from
    # the history, so the placeholder content does not need a forced rerun.
    st.session_state.messages.append(
        {"role": "assistant", "content": full_response}
    )