Léo Bourrel commited on
Commit
ac46aeb
·
1 Parent(s): b7f6a3a

feat: move chain to specific file

Browse files
Files changed (2) hide show
  1. sorbobotapp/app.py +2 -35
  2. sorbobotapp/chain.py +40 -0
sorbobotapp/app.py CHANGED
@@ -5,17 +5,11 @@ import streamlit as st
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
 
8
- from langchain.chains.conversation.memory import ConversationBufferMemory
9
- from langchain.embeddings import GPT4AllEmbeddings
10
- from langchain.llms import OpenAI
11
-
12
  from chat_history import insert_chat_history, insert_chat_history_articles
13
  from connection import connect
14
  from css import load_css
15
  from message import Message
16
- from vector_store import CustomVectorStore
17
- from conversation_retrieval_chain import CustomConversationalRetrievalChain
18
-
19
 
20
  st.set_page_config(layout="wide")
21
 
@@ -32,34 +26,7 @@ def initialize_session_state():
32
  if "token_count" not in st.session_state:
33
  st.session_state.token_count = 0
34
  if "conversation" not in st.session_state:
35
- embeddings = GPT4AllEmbeddings()
36
-
37
- db = CustomVectorStore(
38
- embedding_function=embeddings,
39
- table_name="article",
40
- column_name="abstract_embedding",
41
- connection=conn,
42
- )
43
-
44
- retriever = db.as_retriever()
45
-
46
- llm = OpenAI(
47
- temperature=0,
48
- openai_api_key=os.environ["OPENAI_API_KEY"],
49
- model="text-davinci-003",
50
- )
51
-
52
- memory = ConversationBufferMemory(
53
- output_key="answer", memory_key="chat_history", return_messages=True
54
- )
55
- st.session_state.conversation = CustomConversationalRetrievalChain.from_llm(
56
- llm=llm,
57
- retriever=retriever,
58
- verbose=True,
59
- memory=memory,
60
- return_source_documents=True,
61
- max_tokens_limit=3700,
62
- )
63
 
64
 
65
  def send_message_callback():
 
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
 
8
+ from chain import get_chain
 
 
 
9
  from chat_history import insert_chat_history, insert_chat_history_articles
10
  from connection import connect
11
  from css import load_css
12
  from message import Message
 
 
 
13
 
14
  st.set_page_config(layout="wide")
15
 
 
26
  if "token_count" not in st.session_state:
27
  st.session_state.token_count = 0
28
  if "conversation" not in st.session_state:
29
+ st.session_state.conversation = get_chain(conn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  def send_message_callback():
sorbobotapp/chain.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import sqlalchemy
4
+ from langchain.chains.conversation.memory import ConversationBufferMemory
5
+ from langchain.embeddings import GPT4AllEmbeddings
6
+ from langchain.llms import OpenAI
7
+
8
+ from conversation_retrieval_chain import CustomConversationalRetrievalChain
9
+ from vector_store import CustomVectorStore
10
+
11
+
12
+ def get_chain(conn: sqlalchemy.engine.Connection):
13
+ embeddings = GPT4AllEmbeddings()
14
+
15
+ db = CustomVectorStore(
16
+ embedding_function=embeddings,
17
+ table_name="article",
18
+ column_name="abstract_embedding",
19
+ connection=conn,
20
+ )
21
+
22
+ retriever = db.as_retriever()
23
+
24
+ llm = OpenAI(
25
+ temperature=0,
26
+ openai_api_key=os.environ["OPENAI_API_KEY"],
27
+ model="text-davinci-003",
28
+ )
29
+
30
+ memory = ConversationBufferMemory(
31
+ output_key="answer", memory_key="chat_history", return_messages=True
32
+ )
33
+ return CustomConversationalRetrievalChain.from_llm(
34
+ llm=llm,
35
+ retriever=retriever,
36
+ verbose=True,
37
+ memory=memory,
38
+ return_source_documents=True,
39
+ max_tokens_limit=3700,
40
+ )