|
import html

import streamlit as st
import torch

from model_handler import load_model_and_tokenizer, generate_response
|
|
|
|
|
# Configure the page; must run before any other Streamlit UI call.
# NOTE(review): page_icon was the mojibake string "π€" — repaired to the
# robot emoji it almost certainly encoded. Confirm against the original asset.
st.set_page_config(
    page_title="Phi-2 Fine-tuned Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="collapsed",
)
|
|
|
|
|
# Inject page-wide CSS: strips Streamlit's default padding, hides the built-in
# header/footer/menu chrome, and styles the chat container, message bubbles
# (user right-aligned blue, assistant left-aligned grey), input row, and the
# model-loading banner. The whole body is a single raw-HTML markdown call.
st.markdown("""

<style>

/* Remove all default padding and margins */

.block-container {

padding: 0 !important;

max-width: 100% !important;

}



/* Hide header and footer */

header, footer, #MainMenu {

visibility: hidden;

}



/* Main container */

.main-container {

max-width: 800px;

margin: 0 auto;

padding: 20px;

}



/* Title */

.title {

text-align: center;

margin-bottom: 20px;

}



/* Chat messages */

.user-message {

background-color: #e6f7ff;

border-radius: 15px;

padding: 10px 15px;

margin: 5px 0;

max-width: 80%;

margin-left: auto;

margin-right: 10px;

color: #000;

}



.assistant-message {

background-color: #f0f0f0;

border-radius: 15px;

padding: 10px 15px;

margin: 5px 0;

max-width: 80%;

margin-left: 10px;

color: #000;

}



/* Input area */

.input-area {

display: flex;

margin-top: 20px;

}



/* Loading message */

.loading {

text-align: center;

padding: 20px;

}

</style>

""", unsafe_allow_html=True)
|
|
|
|
|
# Lazily create every session-state slot this app relies on, so reruns
# never clobber existing chat history or an already-loaded model.
for _slot, _initial in (("messages", []), ("model", None), ("tokenizer", None)):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
|
|
|
|
|
# Open the centered layout wrapper, then render the page heading inside it.
_header_fragments = (
    '<div class="main-container">',
    '<h1 class="title">Phi-2 Fine-tuned Assistant</h1>',
)
for _fragment in _header_fragments:
    st.markdown(_fragment, unsafe_allow_html=True)
|
|
|
|
|
# First run only: load the model/tokenizer, then rerun so the chat UI renders.
#
# BUG FIX: the original called st.rerun() inside the try block. st.rerun()
# works by raising Streamlit's RerunException, which the broad
# `except Exception` caught — so every successful load was misreported as
# "Error loading model" and the app stopped. Only the load call is guarded
# now; the success path lives in `else`, outside the handler's reach.
if st.session_state.model is None:
    st.markdown('<div class="loading">Loading model... This may take a minute.</div>', unsafe_allow_html=True)
    try:
        model, tokenizer = load_model_and_tokenizer()
    except Exception as e:
        # Nothing below can work without a model: surface the error and halt.
        st.error(f"Error loading model: {str(e)}")
        st.stop()
    else:
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.rerun()
|
|
|
|
|
# Render the conversation so far as styled chat bubbles.
#
# SECURITY FIX: message content is interpolated into raw HTML under
# unsafe_allow_html=True, so unescaped user (or model) text could inject
# markup/scripts into the page. html.escape() neutralizes that.
for message in st.session_state.messages:
    bubble_class = "user-message" if message["role"] == "user" else "assistant-message"
    safe_content = html.escape(message["content"])
    st.markdown(f'<div class="{bubble_class}">{safe_content}</div>', unsafe_allow_html=True)
|
|
|
|
|
# Message entry box with Send / Clear controls beneath it.
user_input = st.text_input(
    "Message:",
    key="user_input",
    placeholder="Type your message here...",
)

send_col, clear_col = st.columns([4, 1])
submit = send_col.button("Send")
clear = clear_col.button("Clear Chat")

# Wipe the conversation history and redraw an empty transcript.
if clear:
    st.session_state.messages = []
    st.rerun()
|
|
|
|
|
# Handle a newly submitted message: record the user turn, stream the model's
# reply token-by-token into a placeholder, record the assistant turn, then
# rerun so the transcript above shows the completed exchange.
#
# SECURITY FIX: the streamed partial response was interpolated into raw HTML
# under unsafe_allow_html=True without escaping, letting model output inject
# markup/scripts; it is now passed through html.escape() before rendering.
if submit and user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.spinner("Assistant is typing..."):
        response = ""
        response_placeholder = st.empty()

        try:
            for token in generate_response(
                st.session_state.model,
                st.session_state.tokenizer,
                user_input
            ):
                response += token
                response_placeholder.markdown(
                    f'<div class="assistant-message">{html.escape(response)}</div>',
                    unsafe_allow_html=True
                )
            st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            # Record the failure as an assistant turn so it remains visible
            # after the rerun below replaces this page's widgets.
            error_msg = f"Error: {str(e)}"
            st.session_state.messages.append({"role": "assistant", "content": error_msg})
            st.error(error_msg)

    st.rerun()
|
|