Spaces:
Sleeping
Sleeping
import json | |
import os | |
import streamlit as st | |
import streamlit.components.v1 as components | |
from langchain.callbacks import get_openai_callback | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.chains.conversation.memory import ConversationBufferMemory | |
from langchain.embeddings import GPT4AllEmbeddings | |
from langchain.llms import OpenAI | |
from chat_history import insert_chat_history, insert_chat_history_articles | |
from connection import connect | |
from css import load_css | |
from message import Message | |
from vector_store import CustomVectorStore | |
# --- Page-level setup -------------------------------------------------------
# Wide layout so the chat and the source documents can sit side by side.
st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")

# Relative widths: the chat column gets twice the width of the document column.
chat_column, doc_column = st.columns(
    [2, 1],
)

# Connection handed to the vector store and to the chat-history inserts below.
# NOTE(review): this statement runs on every script rerun; whether connect()
# reuses an existing connection is not visible from this file - confirm.
conn = connect()
def initialize_session_state():
    """Seed ``st.session_state`` with the entries the rest of the script uses.

    Every key is created only when it is absent, so values already stored in
    the session are left untouched:

    * ``history``: list of chat messages, initially empty.
    * ``token_count``: running token total, initially ``0``.
    * ``conversation``: a ``ConversationalRetrievalChain`` whose retriever is
      backed by ``CustomVectorStore`` over the ``article`` table's
      ``abstract_embedding`` column.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is missing from the environment when
            the conversation chain has to be built.
    """
    state = st.session_state

    if "history" not in state:
        state.history = []

    if "token_count" not in state:
        state.token_count = 0

    # The chain is the expensive part; build it at most once per session.
    if "conversation" in state:
        return

    vector_store = CustomVectorStore(
        embedding_function=GPT4AllEmbeddings(),
        table_name="article",
        column_name="abstract_embedding",
        connection=conn,
    )
    article_retriever = vector_store.as_retriever()

    # NOTE(review): "text-davinci-003" appears on OpenAI's list of retired
    # models - confirm this deployment still resolves to a working model.
    completion_llm = OpenAI(
        temperature=0,
        openai_api_key=os.environ["OPENAI_API_KEY"],
        model="text-davinci-003",
    )

    chat_memory = ConversationBufferMemory(
        output_key="answer",
        memory_key="chat_history",
        return_messages=True,
    )

    state.conversation = ConversationalRetrievalChain.from_llm(
        llm=completion_llm,
        retriever=article_retriever,
        verbose=True,
        memory=chat_memory,
        return_source_documents=True,
    )
def send_message_callback():
    """Send the current prompt to the conversation chain and record the result.

    Reads ``st.session_state.human_prompt``, strips it, and returns without
    doing anything else when it is empty. Otherwise it calls the stored
    conversation chain, appends the human message and the AI answer (with its
    source documents) to ``st.session_state.history``, and adds the tokens
    reported by the OpenAI callback to ``st.session_state.token_count``.

    When the ``ENVIRONMENT`` environment variable equals ``"dev"``, the
    exchange and its source documents are also written through
    ``insert_chat_history`` / ``insert_chat_history_articles``.
    """
    with st.spinner("Wait for it..."), get_openai_callback() as usage:
        question = st.session_state.human_prompt.strip()
        if not question:
            return

        result = st.session_state.conversation(question)

        # Record the human turn first, then the AI turn with its sources.
        st.session_state.history.append(Message("human", question))
        ai_message = Message(
            "ai",
            result["answer"],
            documents=result["source_documents"],
        )
        st.session_state.history.append(ai_message)

        st.session_state.token_count += usage.total_tokens

        # Persisting the exchange is limited to the "dev" environment.
        if os.environ.get("ENVIRONMENT") == "dev":
            history_id = insert_chat_history(conn, question, result["answer"])
            insert_chat_history_articles(
                conn, history_id, result["source_documents"]
            )
def exemple_message_callback_button(args):
    """Submit an example question as if the user had typed it.

    Args:
        args: the example question text; it is stored as the current prompt
            before ``send_message_callback`` is invoked.
    """
    session = st.session_state

    session.human_prompt = args
    send_message_callback()
    # Reset the stored prompt once the example has been sent.
    session.human_prompt = ""
def clear_history():
    """Reset the session: empty the history, zero the tokens, clear chain memory."""
    session = st.session_state

    session.history.clear()
    session.token_count = 0
    session.conversation.memory.clear()
# --- Bootstrap ---------------------------------------------------------------
# Styles first, then the session entries that the UI code below reads.
load_css()
initialize_session_state()

# Sample questions offered as one-click buttons.
exemples = [
    "Who has published influential research on quantum computing?",
    "List any prominent authors in the field of artificial intelligence ethics?",
    "Who are the leading experts on climate change mitigation strategies?",
]
def _chat_row_html(chat):
    """Return the HTML for one chat row, with the message text HTML-escaped.

    Args:
        chat: a message object exposing ``origin`` (``"ai"`` or anything else,
            treated as the human side) and ``message`` (the text to show).

    Returns:
        str: markup for the row, icon and bubble, meant to be rendered with
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html  # local import keeps this block self-contained

    from_ai = chat.origin == "ai"
    row_classes = "chat-row" if from_ai else "chat-row row-reverse"
    icon_file = "ai_icon.png" if from_ai else "user_icon.png"
    bubble_classes = "chat-bubble " + ("ai-bubble" if from_ai else "human-bubble")

    # The message is raw user / model text and the result is rendered with
    # unsafe_allow_html=True, so it must be escaped: otherwise typed markup is
    # injected into the page and any "<" or "&" garbles the bubble.
    safe_message = html.escape(str(chat.message))

    # No leading indentation on purpose; "\u200b" is the zero-width space that
    # precedes the message text in the bubble.
    return (
        f'<div class="{row_classes}">\n'
        f'<img class="chat-icon" src="./app/static/{icon_file}"\n'
        "width=32 height=32>\n"
        f'<div class="{bubble_classes}">\n'
        f"\u200b{safe_message}\n"
        "</div>\n"
        "</div>\n"
    )


with chat_column:
    # Containers are created up front so their on-page order is fixed
    # (messages, then the form, then the info area) regardless of fill order.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form", clear_on_submit=True)
    information_placeholder = st.container()

    with chat_placeholder:
        for chat in st.session_state.history:
            st.markdown(_chat_row_html(chat), unsafe_allow_html=True)

        # Three empty markdown elements after the last message.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=send_message_callback,
        )

    # Example questions are offered only while no tokens have been used.
    if st.session_state.token_count == 0:
        information_placeholder.markdown("### Test me !")
        for idx_exemple, exemple in enumerate(exemples):
            information_placeholder.button(
                exemple,
                key=f"{idx_exemple}_button",
                on_click=exemple_message_callback_button,
                args=(exemple,),
            )

    st.button(
        ":new: Start a new conversation",
        on_click=clear_history,
        type="secondary",
    )

    # Token usage plus a raw dump of the stored history (debug aid).
    information_placeholder.caption(
        f"\nUsed {st.session_state.token_count} tokens \n\n"
        "Debug Langchain conversation:\n"
        f"{st.session_state.history}\n"
    )

    # Invisible component whose script clicks the "Submit" button when Enter
    # is pressed anywhere in the parent document.
    # NOTE(review): this call runs on every rerun, so keydown listeners may
    # accumulate on the parent document, and `submitButton` is not checked
    # for undefined - confirm against the deployed Streamlit version.
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
    switch (e.key) {
        case 'Enter':
            submitButton.click();
            break;
    }
});
</script>
""",
        height=0,
        width=0,
    )
with doc_column:
    st.markdown("**Source documents**")

    # Only the most recent message's documents are listed.
    if st.session_state.history:
        latest_message = st.session_state.history[-1]
        for source in latest_message.documents:
            # page_content is parsed as JSON; the keys read below are title,
            # hal_id, abstract, authors, keywords and distance.
            record = json.loads(source.page_content)

            panel = st.expander(record["title"])
            panel.markdown(f"**HalID** : https://hal.science/{record['hal_id']}")
            panel.markdown(record["abstract"])
            for label, field in (
                ("Authors", "authors"),
                ("Keywords", "keywords"),
                ("Distance", "distance"),
            ):
                panel.markdown(f"**{label}** : {record[field]}")