Merge branch 'feat/limit_or_rephrase' into main
- app.py +5 -2
- conversation_retrieval_chain.py +64 -0
- models/distance.py +7 -0
- requirements.txt +5 -4
- vector_store.py +7 -4
app.py
CHANGED
@@ -4,7 +4,7 @@ import os
 import streamlit as st
 import streamlit.components.v1 as components
 from langchain.callbacks import get_openai_callback
-
+
 from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.embeddings import GPT4AllEmbeddings
 from langchain.llms import OpenAI
@@ -14,6 +14,8 @@ from connection import connect
 from css import load_css
 from message import Message
 from vector_store import CustomVectorStore
+from conversation_retrieval_chain import CustomConversationalRetrievalChain
+
 
 st.set_page_config(layout="wide")
 
@@ -50,12 +52,13 @@ def initialize_session_state():
     memory = ConversationBufferMemory(
         output_key="answer", memory_key="chat_history", return_messages=True
     )
-    st.session_state.conversation =
+    st.session_state.conversation = CustomConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        verbose=True,
        memory=memory,
        return_source_documents=True,
+       max_tokens_limit=3700,
     )
 
 
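The switch to CustomConversationalRetrievalChain.from_llm(...) keeps the standard construction path but adds max_tokens_limit=3700, which the base chain uses to trim retrieved documents so the stuffed prompt stays within the model's context budget (that token counting is also why tiktoken appears in requirements.txt below). A minimal sketch of how the resulting chain might be driven from the Streamlit side; the input widget and display calls are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (not from this commit). The chain is callable
# with a dict; ConversationBufferMemory(memory_key="chat_history") supplies
# the history itself, so only "question" is required in the inputs.
import streamlit as st
from langchain.callbacks import get_openai_callback

prompt = st.text_input("Ask a question")  # assumed input widget
if prompt:
    with get_openai_callback() as cb:  # tracks OpenAI token usage and cost
        result = st.session_state.conversation({"question": prompt})
    st.write(result["answer"])
    # return_source_documents=True exposes the retrieved rows as well
    for doc in result.get("source_documents", []):
        st.write(doc.page_content[:200])
    st.caption(f"Total tokens: {cb.total_tokens}")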
conversation_retrieval_chain.py
ADDED
@@ -0,0 +1,64 @@
+import inspect
+from typing import Any, Dict, Optional
+
+from langchain.chains.conversational_retrieval.base import (
+    ConversationalRetrievalChain,
+    _get_chat_history,
+)
+from langchain.callbacks.manager import CallbackManagerForChainRun
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+    def _handle_docs(self, docs):
+        if len(docs) == 0:
+            return False, "No documents found. Can you rephrase ?"
+        elif len(docs) == 1:
+            return False, "Only one document found. Can you rephrase ?"
+        elif len(docs) > 10:
+            return False, "Too many documents found. Can you specify your request ?"
+        return True, ""
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self.get_chat_history or _get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+
+        if chat_history_str:
+            callbacks = _run_manager.get_child()
+            new_question = self.question_generator.run(
+                question=question, chat_history=chat_history_str, callbacks=callbacks
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._get_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
+        valid_docs, message = self._handle_docs(docs)
+        if not valid_docs:
+            return {
+                self.output_key: message,
+                "source_documents": docs,
+            }
+
+        new_inputs = inputs.copy()
+        if self.rephrase_question:
+            new_inputs["question"] = new_question
+        new_inputs["chat_history"] = chat_history_str
+        answer = self.combine_docs_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+        )
+        output: Dict[str, Any] = {self.output_key: answer}
+        if self.return_source_documents:
+            output["source_documents"] = docs
+        if self.return_generated_question:
+            output["generated_question"] = new_question
+        return output
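_call mirrors the upstream ConversationalRetrievalChain._call up to retrieval, then inserts the _handle_docs guard: with 0, 1, or more than 10 hits it skips answer generation entirely and returns a canned rephrase prompt as the answer, while still surfacing the retrieved documents. A self-contained restatement of the guard with boundary checks, to make the thresholds explicit (illustrative only, not part of the commit):

# Hypothetical test, not part of the commit: restates the _handle_docs
# branch logic so the cut-offs can be sanity-checked without
# instantiating the full chain.
def handle_docs(docs):
    if len(docs) == 0:
        return False, "No documents found. Can you rephrase ?"
    elif len(docs) == 1:
        return False, "Only one document found. Can you rephrase ?"
    elif len(docs) > 10:
        return False, "Too many documents found. Can you specify your request ?"
    return True, ""

assert handle_docs([]) == (False, "No documents found. Can you rephrase ?")
assert handle_docs(["d"])[0] is False       # a single hit is rejected too
assert handle_docs(["d"] * 2)[0] is True    # 2..10 documents pass through
assert handle_docs(["d"] * 10)[0] is True
assert handle_docs(["d"] * 11)[0] is False  # >10 asks the user to narrow down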
models/distance.py
CHANGED
@@ -1,6 +1,13 @@
 import enum
 
 
+distance_strategy_limit = {
+    "l2": 1.05,
+    "cosine": 0.55,
+    "inner": 1.0,
+}
+
+
 class DistanceStrategy(str, enum.Enum):
     """Enumerator of the Distance strategies."""
 
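The dict pins a "close enough" cut-off per distance metric, since the raw scores live on different scales (pgvector's cosine distance lies in [0, 2], L2 distance is unbounded, and the inner-product operator returns a negated similarity), so one global threshold could not serve all three. A toy illustration of the cosine case, assuming pgvector's 1 - cosine-similarity semantics; the vectors are made up:

import math

def cosine_distance(a, b):
    # pgvector's <=> operator: 1 - cosine similarity, range [0, 2]
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (norm_a * norm_b)

query, doc = [1.0, 0.0], [0.8, 0.6]
d = cosine_distance(query, doc)  # 0.2 for this pair
print(d < 0.55)                  # True -> survives the "cosine" limit above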
requirements.txt
CHANGED
@@ -1,11 +1,12 @@
 gpt4all==1.0.12
 langchain==0.0.313
 openai==0.28.1
+pandas==2.1.1
+pgvector==0.2.3
+psycopg2-binary==2.9.9
+psycopg2==2.9.9
 streamlit==1.27.2
 streamlit-chat==0.1.1
 SQLAlchemy==2.0.22
 sqlite-vss==0.1.2
-
-pgvector==0.2.3
-psycopg2-binary==2.9.9
-psycopg2==2.9.9
+tiktoken==0.5.1
vector_store.py
CHANGED
@@ -14,10 +14,10 @@ from sqlalchemy import delete, text
 from sqlalchemy.orm import Session
 
 from model import Article
-from models.distance import DistanceStrategy
+from models.distance import DistanceStrategy, distance_strategy_limit
 from utils import str_to_list
 
-DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
 
 _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
 
@@ -252,6 +252,8 @@ class CustomVectorStore(VectorStore):
         k: int = 4,
     ) -> List[Any]:
         """Query the collection."""
+
+        limit = distance_strategy_limit[self._distance_strategy]
         with Session(self._conn) as session:
             results = session.execute(
                 text(
@@ -272,10 +274,11 @@ class CustomVectorStore(VectorStore):
                     left join author on author.id = article_author.author_id
                     where
                         abstract_en != '' and
-                        abstract_en != 'None'
+                        abstract_en != 'None' and
+                        abstract_embedding_en {self.distance_strategy} '{str(embedding)}' < {limit}
                     GROUP BY a.id
                     ORDER BY distance
-                    LIMIT
+                    LIMIT 100;
                     """
                 )
             )
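Together with models/distance.py, the new clause filters rows by embedding distance before the LIMIT 100 cap, so the chain's _handle_docs guard only ever sees matches under the per-strategy threshold. A sketch of what the interpolated predicate might expand to, assuming self.distance_strategy renders as a pgvector operator such as <=> for cosine (the enum's actual values are not shown in this diff):

# Illustration only: how the f-string predicate could render at runtime.
# "<=>" (cosine distance) is an assumed operator value, and the embedding
# is a toy vector; neither comes from this commit.
embedding = [0.12, -0.03, 0.98]
operator = "<=>"  # assumed pgvector cosine-distance operator
limit = {"l2": 1.05, "cosine": 0.55, "inner": 1.0}["cosine"]

predicate = f"abstract_embedding_en {operator} '{str(embedding)}' < {limit}"
print(predicate)
# abstract_embedding_en <=> '[0.12, -0.03, 0.98]' < 0.55

Two observations on the diff itself: the dict lookup reads self._distance_strategy while the SQL interpolates self.distance_strategy, so presumably one is a property over the other; and since the interpolated vector is a list of floats produced in-process by the embedding model, building the SQL with an f-string carries little injection risk here, though bound parameters would be the more conventional choice.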