"""Streamlit front end for Sorbobot, a retrieval-augmented chat over articles.

The page is split into a chat column (conversation, prompt form, example
prompts, token usage) and a document column listing the source documents
returned with the latest answer. Streamlit re-executes this whole script on
every interaction, so long-lived objects are kept in ``st.session_state``.
"""
import html
import json
import os

import streamlit as st
import streamlit.components.v1 as components
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import OpenAI

from chat_history import insert_chat_history, insert_chat_history_articles
from connection import connect
from css import load_css
from message import Message
from vector_store import CustomVectorStore

st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")
chat_column, doc_column = st.columns([2, 1])

# Streamlit re-runs this script on every widget interaction, so calling
# connect() unconditionally here would request a new connection on each rerun
# while the retriever built in initialize_session_state() keeps using the
# first one. Open the connection once per session and reuse it everywhere.
# NOTE(review): if connect() already caches its result internally, this is
# merely redundant, not harmful.
if "db_connection" not in st.session_state:
    st.session_state.db_connection = connect()
conn = st.session_state.db_connection


def initialize_session_state():
    """Create the ``st.session_state`` entries this page relies on.

    Each entry is created only when missing, so this is safe to call on every
    rerun:

    - ``history``: list of ``Message`` objects displayed in the chat.
    - ``token_count``: OpenAI tokens consumed by the current conversation.
    - ``conversation``: a ``ConversationalRetrievalChain`` that retrieves from
      the ``article`` table (``abstract_embedding`` column) and keeps the
      dialogue in a ``ConversationBufferMemory``.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "conversation" not in st.session_state:
        embeddings = GPT4AllEmbeddings()
        db = CustomVectorStore(
            embedding_function=embeddings,
            table_name="article",
            column_name="abstract_embedding",
            connection=conn,
        )
        retriever = db.as_retriever()
        llm = OpenAI(
            temperature=0,
            openai_api_key=os.environ["OPENAI_API_KEY"],
            model="text-davinci-003",
        )
        # With return_source_documents=True the chain returns both "answer"
        # and "source_documents"; output_key tells the memory which to store.
        memory = ConversationBufferMemory(
            output_key="answer", memory_key="chat_history", return_messages=True
        )
        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            verbose=True,
            memory=memory,
            return_source_documents=True,
        )


def send_message_callback():
    """Send the prompt in ``st.session_state.human_prompt`` to the chain.

    Blank (whitespace-only) prompts are ignored. Otherwise the human prompt
    and the AI answer (with its source documents) are appended to
    ``history``, and the tokens used by the call are added to
    ``token_count``. When the ``ENVIRONMENT`` environment variable is
    ``"dev"``, the exchange and its source articles are also stored through
    ``insert_chat_history`` / ``insert_chat_history_articles``.
    """
    human_prompt = st.session_state.human_prompt.strip()
    if not human_prompt:
        return

    with st.spinner("Wait for it..."):
        with get_openai_callback() as cb:
            llm_response = st.session_state.conversation(human_prompt)
            answer = llm_response["answer"]
            source_documents = llm_response["source_documents"]

            st.session_state.history.append(Message("human", human_prompt))
            st.session_state.history.append(
                Message("ai", answer, documents=source_documents)
            )
            st.session_state.token_count += cb.total_tokens

            if os.environ.get("ENVIRONMENT") == "dev":
                history_id = insert_chat_history(conn, human_prompt, answer)
                insert_chat_history_articles(conn, history_id, source_documents)


def exemple_message_callback_button(args):
    """Send an example prompt as if the user had typed it.

    Args:
        args: The example prompt text, passed through ``st.button(args=...)``.
    """
    st.session_state.human_prompt = args
    send_message_callback()
    # Empty the input afterwards so the example text does not stay in the box.
    st.session_state.human_prompt = ""


def clear_history():
    """Start over: clear the displayed history, token count and chain memory."""
    st.session_state.history.clear()
    st.session_state.token_count = 0
    st.session_state.conversation.memory.clear()


load_css()
initialize_session_state()

exemples = [
    "Who has published influential research on quantum computing?",
    "List any prominent authors in the field of artificial intelligence ethics?",
    "Who are the leading experts on climate change mitigation strategies?",
]

with chat_column:
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form", clear_on_submit=True)
    information_placeholder = st.container()

    with chat_placeholder:
        for chat in st.session_state.history:
            # Messages carry raw user input and model output: escape them
            # before rendering with unsafe_allow_html so they cannot inject
            # markup, and so text such as "a < b" is displayed verbatim.
            # NOTE(review): the purpose of the leading U+200B is not visible
            # here; presumably it keeps the element non-empty — confirm.
            div = f"""
\u200b{html.escape(str(chat.message))}
"""
            st.markdown(div, unsafe_allow_html=True)
        # Three empty markdown elements pad the end of the chat area.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=send_message_callback,
        )

    # Example prompts are offered only while no tokens have been used yet.
    if st.session_state.token_count == 0:
        information_placeholder.markdown("### Test me !")
        for idx_exemple, exemple in enumerate(exemples):
            information_placeholder.button(
                exemple,
                key=f"{idx_exemple}_button",
                on_click=exemple_message_callback_button,
                args=(exemple,),
            )

    st.button(
        ":new: Start a new conversation",
        on_click=clear_history,
        type="secondary",
    )

    information_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation: {st.session_state.history}
"""
    )

    # NOTE(review): the HTML/JS payload is a single space in this revision,
    # so this zero-size component currently renders nothing — confirm whether
    # a script was meant to be embedded here.
    components.html(
        """ """,
        height=0,
        width=0,
    )

with doc_column:
    st.markdown("**Source documents**")
    if st.session_state.history:
        # send_message_callback appends the human message and then the AI
        # message, so the last entry is always the AI answer with its sources.
        for doc in st.session_state.history[-1].documents:
            # Each document's page_content is a JSON object with these fields.
            doc_content = json.loads(doc.page_content)
            expander = st.expander(doc_content["title"])
            expander.markdown(f"**HalID** : https://hal.science/{doc_content['hal_id']}")
            expander.markdown(doc_content["abstract"])
            expander.markdown(f"**Authors** : {doc_content['authors']}")
            expander.markdown(f"**Keywords** : {doc_content['keywords']}")
            expander.markdown(f"**Distance** : {doc_content['distance']}")