import streamlit as st
import torch
from model_handler import load_model_and_tokenizer, generate_response
import time

st.set_page_config(
    page_title="Phi-2 Fine-tuned Assistant",
    page_icon="🤖",
    layout="centered"
)
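
# Custom CSS: chat bubbles for each role, rounded input and button, a hidden
# default header to reclaim vertical space, and a dark app background.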
st.markdown("""
<style>
.user-bubble {
    background-color: #e6f7ff;
    border-radius: 15px;
    padding: 10px 15px;
    margin: 5px 0;
    max-width: 80%;
    margin-left: auto;
    margin-right: 10px;
    position: relative;
    color: #000000; /* Ensure text is black */
}
.assistant-bubble {
    background-color: #f0f0f0;
    border-radius: 15px;
    padding: 10px 15px;
    margin: 5px 0;
    max-width: 80%;
    margin-left: 10px;
    position: relative;
    color: #000000; /* Ensure text is black */
}
.chat-container {
    display: flex;
    flex-direction: column;
    height: calc(100vh - 200px);
    overflow-y: auto;
    padding: 10px;
    margin-bottom: 20px;
}
.stTextInput>div>div>input {
    border-radius: 20px;
}
.stButton>button {
    border-radius: 20px;
    width: 100%;
}
/* Fix for excessive vertical space */
header {
    visibility: hidden;
}
.block-container {
    padding-top: 1rem;
    padding-bottom: 1rem;
}
h1 {
    margin-top: 0 !important;
    margin-bottom: 1rem !important;
}
/* Ensure dark mode compatibility */
.stApp {
    background-color: #121212;
}
</style>
""", unsafe_allow_html=True)
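
# Conversation history, the model-loaded flag, and an input-clearing flag are
# kept in session state so they persist across Streamlit reruns.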
if "messages" not in st.session_state:
    st.session_state.messages = []

if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False

if "clear_input" not in st.session_state:
    st.session_state.clear_input = False

st.markdown("<h2>Phi-2 Fine-tuned Assistant</h2>", unsafe_allow_html=True)
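
# model_handler is assumed to expose load_model_and_tokenizer(), returning a
# (model, tokenizer) pair, and generate_response(model, tokenizer, prompt),
# yielding decoded text chunks for streaming (this matches how they are used
# below). The loaded objects are cached in session state; st.cache_resource
# would be an alternative if the model should be shared across sessions.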
if not st.session_state.model_loaded:
    with st.spinner("Loading the fine-tuned model... This may take a minute."):
        model, tokenizer = load_model_and_tokenizer()
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.model_loaded = True
        st.success("Model loaded successfully!")
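
# The chat history is redrawn from session state on every rerun. Note that each
# st.markdown call renders as its own element, so the .chat-container div
# opened here does not actually wrap the bubbles; the per-bubble classes still
# provide the styling.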
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    role = message["role"]
    content = message["content"]

    if role == "user":
        st.markdown(f'<div class="user-bubble">{content}</div>', unsafe_allow_html=True)
    else:
        st.markdown(f'<div class="assistant-bubble">{content}</div>', unsafe_allow_html=True)
st.markdown('</div>', unsafe_allow_html=True)

# A widget's session-state key cannot be reassigned after the widget has been
# instantiated in the current run, so the input box is cleared via a flag that
# is checked here, before the widget is created.
if st.session_state.clear_input:
    st.session_state.user_input = ""
    st.session_state.clear_input = False

with st.container():
    col1, col2 = st.columns([5, 1])
    with col1:
        user_input = st.text_input("Message", key="user_input", placeholder="Type your message here...", label_visibility="collapsed")
    with col2:
        clear_button = st.button("Clear")
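
# When the user submits a message: store it, stream the assistant's reply into
# a placeholder, save the reply, then rerun so the history above shows both.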
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    assistant_response_placeholder = st.empty()
    full_response = ""

    with assistant_response_placeholder.container():
        st.markdown('<div class="assistant-bubble">Assistant is typing...</div>', unsafe_allow_html=True)
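
    # Stream the reply: each chunk yielded by generate_response is appended to
    # full_response and the assistant bubble is redrawn in place.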
    for token in generate_response(
        st.session_state.model,
        st.session_state.tokenizer,
        user_input
    ):
        full_response += token

        with assistant_response_placeholder.container():
            st.markdown(f'<div class="assistant-bubble">{full_response}</div>', unsafe_allow_html=True)
        time.sleep(0.01)

    st.session_state.messages.append({"role": "assistant", "content": full_response})

    # Reassigning st.session_state.user_input here would raise a
    # StreamlitAPIException because the widget already exists in this run, so
    # flag the input box to be cleared at the top of the next run instead.
    st.session_state.clear_input = True

    st.experimental_rerun()
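
# The Clear button wipes the stored conversation and reruns to redraw an empty
# chat. Newer Streamlit releases also expose st.experimental_rerun() as
# st.rerun().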
if clear_button:
    st.session_state.messages = []
    st.experimental_rerun()