"""Streamlit chat UI for a fine-tuned Phi-2 model.

The model and tokenizer are loaded once per browser session (cached in
``st.session_state``) via ``model_handler.load_model_and_tokenizer``. Replies
are produced by iterating ``model_handler.generate_response`` and re-rendering
the partial text after each chunk.
"""

import html

import streamlit as st
import torch  # noqa: F401  # not referenced in this file; kept as in the original
from model_handler import load_model_and_tokenizer, generate_response

# Page configuration (must be the first Streamlit command in the script).
st.set_page_config(
    page_title="Phi-2 Fine-tuned Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Clean, minimal CSS.
# NOTE(review): the original stylesheet and the HTML markup inside the
# st.markdown calls appear to have been stripped from this file; the class
# names and rules below are a minimal reconstruction. Replace them with the
# intended styling if it can be recovered.
_CSS = """
<style>
.chat-header h1 { text-align: center; }
.chat-message { padding: 0.75rem 1rem; border-radius: 0.75rem; margin: 0.5rem 0; }
.user-message { background-color: rgba(49, 130, 206, 0.15); margin-left: 20%; }
.assistant-message { background-color: rgba(128, 128, 128, 0.12); margin-right: 20%; }
.loading-message { text-align: center; opacity: 0.7; }
</style>
"""
st.markdown(_CSS, unsafe_allow_html=True)


def _bubble_html(text: str, css_class: str) -> str:
    """Return *text* wrapped in a chat-bubble ``<div>``, with HTML escaped.

    The result is rendered with ``unsafe_allow_html=True``, so user input and
    model output must be escaped or they could inject arbitrary HTML into the
    page. Newlines become ``<br>`` because a blank line inside a Markdown HTML
    block ends the block, which would spill the rest of the text (and the
    closing ``</div>``) out of the bubble.
    """
    safe = html.escape(text).replace("\n", "<br>")
    return f'<div class="chat-message {css_class}">{safe}</div>'


def _render_bubble(target, text: str, css_class: str) -> None:
    """Render a chat bubble into *target* (the ``st`` module or an ``st.empty()`` slot)."""
    target.markdown(_bubble_html(text, css_class), unsafe_allow_html=True)


# Initialize session state (only keys that are missing, so values survive reruns).
for _key, _default in (("messages", []), ("model", None), ("tokenizer", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Header.
st.markdown(
    '<div class="chat-header"><h1>Phi-2 Fine-tuned Assistant</h1></div>',
    unsafe_allow_html=True,
)

# Load model if not already loaded for this session.
if st.session_state.model is None:
    # A placeholder lets us remove the notice once loading finishes,
    # without needing a full st.rerun().
    loading_notice = st.empty()
    loading_notice.markdown(
        '<div class="loading-message">Loading model... This may take a minute.</div>',
        unsafe_allow_html=True,
    )
    try:
        model, tokenizer = load_model_and_tokenizer()
    except Exception as e:  # top-level UI boundary: report and halt this run
        loading_notice.empty()
        st.error(f"Error loading model: {e}")
        st.stop()
    st.session_state.model = model
    st.session_state.tokenizer = tokenizer
    loading_notice.empty()

# Display chat history.
for message in st.session_state.messages:
    css_class = "user-message" if message["role"] == "user" else "assistant-message"
    _render_bubble(st, message["content"], css_class)

# Input box (outside form to keep state during generation)
user_input = st.text_input("Message:", key="user_input", placeholder="Type your message here...")

# Submit and Clear buttons
col1, col2 = st.columns([4, 1])
submit = col1.button("Send")
clear = col2.button("Clear Chat")

# Handle Clear button
if clear:
    st.session_state.messages = []
    st.rerun()

# Handle user input + generate response (ignore empty / whitespace-only input).
if submit and user_input.strip():
    st.session_state.messages.append({"role": "user", "content": user_input})
    # The history loop above has already run, so show the new message now;
    # otherwise it would only appear after the reply finishes.
    _render_bubble(st, user_input, "user-message")

    with st.spinner("Assistant is typing..."):
        response = ""
        response_placeholder = st.empty()
        try:
            for token in generate_response(
                st.session_state.model, st.session_state.tokenizer, user_input
            ):
                response += token
                # Re-render the growing reply in place after each chunk.
                _render_bubble(response_placeholder, response, "assistant-message")
        except Exception as e:  # top-level UI boundary: record the failure in the chat
            error_msg = f"Error: {e}"
            st.session_state.messages.append({"role": "assistant", "content": error_msg})
            st.error(error_msg)
        else:
            st.session_state.messages.append({"role": "assistant", "content": response})

    # Rerun outside the try so the history loop redraws the full conversation.
    st.rerun()