"""Sorbobot: a Streamlit chat front-end over a PGVector-backed LangChain retrieval chain."""

import html
import os

import streamlit as st
import streamlit.components.v1 as components
from css import load_css
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.pgvector import PGVector
from message import Message

CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
# NOTE(review): an empty collection name is passed straight to PGVector below;
# confirm this is intended and not a leftover placeholder.
COLLECTION_NAME = ""


def initialize_session_state():
    """Create the per-session objects the app needs, once per browser session.

    Populates ``st.session_state`` with:
      * ``history``     - list of ``Message`` objects rendered in the chat pane;
      * ``token_count`` - cumulative OpenAI tokens reported by the callback;
      * ``memory``      - the conversation memory attached to the chain;
      * ``conversation``- the ``ConversationalRetrievalChain`` used to answer.
    Existing entries are left untouched, so Streamlit reruns do not reset them.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "conversation" not in st.session_state:
        embeddings = OpenAIEmbeddings()
        store = PGVector(
            collection_name=COLLECTION_NAME,
            connection_string=CONNECTION_STRING,
            embedding_function=embeddings,
        )
        retriever = store.as_retriever()
        llm = OpenAI(
            temperature=0,
            openai_api_key=os.environ["OPENAI_API_KEY"],
            # NOTE(review): OpenAI has retired text-davinci-003; confirm the model name.
            model="text-davinci-003",
        )
        # Attach the memory to the chain so the chain itself loads
        # ``chat_history`` before each call and saves the new question/answer
        # pair afterwards. Previously the memory was never written to, and its
        # string ``buffer`` is not a chat-history format the chain accepts;
        # ``return_messages=True`` stores message objects instead.
        st.session_state.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )
        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=st.session_state.memory,
        )


def on_click_callback():
    """Answer the submitted prompt and record the exchange.

    Reads the prompt from ``st.session_state.human_prompt`` (the form's text
    input), runs it through the retrieval chain, appends the human and AI
    messages to ``history`` and adds the tokens reported by the OpenAI
    callback to ``token_count``. Blank prompts are ignored so they do not
    trigger a billed LLM call.
    """
    human_prompt = st.session_state.human_prompt
    if not human_prompt.strip():
        return
    with get_openai_callback() as cb:
        # chat_history is loaded from, and saved back to, the chain's memory.
        llm_response = st.session_state.conversation.run({"question": human_prompt})
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", llm_response))
        st.session_state.token_count += cb.total_tokens


load_css()
initialize_session_state()

st.title("Sorbobot - Le futur de la recherche scientifique interactive")

chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
information_placeholder = st.empty()

with chat_placeholder:
    for chat in st.session_state.history:
        # Message text comes from the user and the LLM, and is rendered with
        # unsafe_allow_html=True, so escape it to keep it from injecting markup.
        # The leading "\u200b" is the zero-width space from the original template.
        # NOTE(review): this template has no wrapper markup around the message;
        # confirm that is intended.
        div = f"\n\u200b{html.escape(str(chat.message))}\n"
        st.markdown(div, unsafe_allow_html=True)
    # Blank lines to leave some space below the last message.
    for _ in range(3):
        st.markdown("")

with prompt_placeholder:
    st.markdown("**Chat**")
    cols = st.columns((6, 1))
    cols[0].text_input(
        "Chat",
        value="Hello bot",
        label_visibility="collapsed",
        key="human_prompt",
    )
    cols[1].form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,
    )

information_placeholder.caption(
    f""" Used {st.session_state.token_count} tokens \n Debug Langchain conversation: {st.session_state.memory.buffer} """
)

# NOTE(review): the HTML payload here is effectively empty, so this renders an
# invisible 0x0 component with no content; confirm whether markup/script was
# meant to be embedded.
components.html(
    """ """,
    height=0,
    width=0,
)