"""Sorbobot — interactive scientific-search chatbot (Streamlit front end).

Wires a LangChain ConversationalRetrievalChain (OpenAI LLM + pgvector
article retriever) into a two-column Streamlit page: the chat on the
left, the source articles backing the latest answer on the right.
"""

import json
import os

import streamlit as st
import streamlit.components.v1 as components
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import OpenAI

from chat_history import insert_chat_history, insert_chat_history_articles
from connection import connect
from css import load_css
from custom_pgvector import CustomPGVector
from message import Message

st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

chat_column, doc_column = st.columns([2, 1])

# Single module-level DB connection, shared by the vector retriever and
# the chat-history writes.
conn = connect()


def initialize_session_state():
    """Create the per-session state keys on first page load.

    Keys created (each only if absent, so reruns keep existing state):
    - ``history``: list of ``Message`` objects rendered in the chat column.
    - ``token_count``: cumulative OpenAI token usage for this session.
    - ``conversation``: the retrieval chain, built once per session.
    """
    if "history" not in st.session_state:
        st.session_state.history = []

    if "token_count" not in st.session_state:
        st.session_state.token_count = 0

    if "conversation" not in st.session_state:
        # Embed article abstracts locally; retrieve from the pgvector-backed
        # "article" table through the shared connection.
        embeddings = GPT4AllEmbeddings()
        db = CustomPGVector(
            embedding_function=embeddings,
            table_name="article",
            column_name="abstract_embedding",
            connection=conn,
        )
        retriever = db.as_retriever()

        llm = OpenAI(
            temperature=0,  # deterministic answers for reproducible citations
            openai_api_key=os.environ["OPENAI_API_KEY"],
            model="text-davinci-003",
        )
        memory = ConversationBufferMemory(
            output_key="answer", memory_key="chat_history", return_messages=True
        )
        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            verbose=True,
            memory=memory,
            return_source_documents=True,
        )


def on_click_callback():
    """Form-submit callback: send the prompt to the chain, record the exchange.

    Appends both the human prompt and the AI answer (with its source
    documents) to ``st.session_state.history``, accumulates token usage,
    and persists the exchange via ``insert_chat_history``.
    """
    human_prompt = st.session_state.human_prompt.strip()
    if not human_prompt:
        # Empty prompt: bail out before opening the spinner or the
        # OpenAI callback context — nothing to do.
        return

    with st.spinner("Wait for it..."):
        with get_openai_callback() as cb:
            llm_response = st.session_state.conversation(human_prompt)
            st.session_state.history.append(Message("human", human_prompt))
            st.session_state.history.append(
                Message(
                    "ai",
                    llm_response["answer"],
                    documents=llm_response["source_documents"],
                )
            )
            st.session_state.token_count += cb.total_tokens
            # NOTE(review): linking the returned history row to its source
            # articles (insert_chat_history_articles) was disabled upstream;
            # only the raw exchange is persisted for now.
            insert_chat_history(conn, human_prompt, llm_response["answer"])


def clear_history():
    """Reset the chat display, the token counter, and the chain's memory."""
    st.session_state.history.clear()
    st.session_state.token_count = 0
    st.session_state.conversation.memory.clear()


load_css()
initialize_session_state()

with chat_column:
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form", clear_on_submit=True)
    information_placeholder = st.container()

    with chat_placeholder:
        # Render the transcript; markup inside each bubble is trusted HTML.
        for chat in st.session_state.history:
            div = f"""
​{chat.message}
"""
            st.markdown(div, unsafe_allow_html=True)
        # Spacer lines between the transcript and the input form.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    st.button(":new: Start a new conversation", on_click=clear_history, type="secondary")

    information_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.history}
"""
    )

    # Invisible component slot (height/width 0) — currently carries no script.
    components.html(
        """
""",
        height=0,
        width=0,
    )

with doc_column:
    st.markdown("**Source documents**")
    if st.session_state.history:
        # Show the articles cited by the most recent AI answer; each
        # document's page_content is a JSON blob of article metadata.
        for doc in st.session_state.history[-1].documents:
            doc_content = json.loads(doc.page_content)
            expander = st.expander(doc_content["title"])
            expander.markdown(f"**DOI : {doc_content['doi']}**")
            expander.markdown(doc_content["abstract"])
            expander.markdown(f"**Authors** : {doc_content['authors']}")
            expander.markdown(f"**Keywords** : {doc_content['keywords']}")
            expander.markdown(f"**Distance** : {doc_content['distance']}")