leo-bourrel committed on
Commit
279f3c6
·
2 Parent(s): f419f72 8203352

Merge branch 'feat/limit_or_rephrase' into main

Browse files
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import streamlit as st
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
- from langchain.chains import ConversationalRetrievalChain
8
  from langchain.chains.conversation.memory import ConversationBufferMemory
9
  from langchain.embeddings import GPT4AllEmbeddings
10
  from langchain.llms import OpenAI
@@ -14,6 +14,8 @@ from connection import connect
14
  from css import load_css
15
  from message import Message
16
  from vector_store import CustomVectorStore
 
 
17
 
18
  st.set_page_config(layout="wide")
19
 
@@ -50,12 +52,13 @@ def initialize_session_state():
50
  memory = ConversationBufferMemory(
51
  output_key="answer", memory_key="chat_history", return_messages=True
52
  )
53
- st.session_state.conversation = ConversationalRetrievalChain.from_llm(
54
  llm=llm,
55
  retriever=retriever,
56
  verbose=True,
57
  memory=memory,
58
  return_source_documents=True,
 
59
  )
60
 
61
 
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
+
8
  from langchain.chains.conversation.memory import ConversationBufferMemory
9
  from langchain.embeddings import GPT4AllEmbeddings
10
  from langchain.llms import OpenAI
 
14
  from css import load_css
15
  from message import Message
16
  from vector_store import CustomVectorStore
17
+ from conversation_retrieval_chain import CustomConversationalRetrievalChain
18
+
19
 
20
  st.set_page_config(layout="wide")
21
 
 
52
  memory = ConversationBufferMemory(
53
  output_key="answer", memory_key="chat_history", return_messages=True
54
  )
55
+ st.session_state.conversation = CustomConversationalRetrievalChain.from_llm(
56
  llm=llm,
57
  retriever=retriever,
58
  verbose=True,
59
  memory=memory,
60
  return_source_documents=True,
61
+ max_tokens_limit=3700,
62
  )
63
 
64
 
conversation_retrieval_chain.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Dict, Optional
3
+
4
+ from langchain.chains.conversational_retrieval.base import (
5
+ ConversationalRetrievalChain,
6
+ _get_chat_history,
7
+ )
8
+ from langchain.callbacks.manager import CallbackManagerForChainRun
9
+
10
+
11
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """ConversationalRetrievalChain variant that gates answer generation on
    the size of the retrieved document set.

    When retrieval returns too few (0 or 1) or too many (> 10) documents,
    the chain short-circuits and asks the user to rephrase instead of
    answering from poor context.
    """

    def _handle_docs(self, docs):
        """Validate the retrieved document set.

        Returns:
            A ``(valid, message)`` pair. ``valid`` is False when the set is
            empty, a singleton, or larger than 10 documents; ``message`` is
            then the user-facing rejection text (empty string otherwise).
        """
        if len(docs) == 0:
            return False, "No documents found. Can you rephrase ?"
        if len(docs) == 1:
            return False, "Only one document found. Can you rephrase ?"
        if len(docs) > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    def _build_output(
        self, answer: str, docs: Any, generated_question: str
    ) -> Dict[str, Any]:
        """Assemble the result dict honoring the chain's output flags.

        Keeps the returned keys consistent with ``output_keys`` regardless
        of which code path produced ``answer``.
        """
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = generated_question
        return output

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run retrieval + answer generation.

        Mirrors ``ConversationalRetrievalChain._call`` with one addition:
        the ``_handle_docs`` gate between retrieval and answer generation.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Optional callback manager for this run.

        Returns:
            Dict with the answer under ``self.output_key`` and, depending
            on the chain's flags, ``"source_documents"`` and
            ``"generated_question"``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        if chat_history_str:
            # Condense the follow-up question with the chat history so the
            # retriever sees a standalone query.
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # Older retriever implementations do not accept run_manager.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            # FIX: previously this early return always included
            # "source_documents" and never "generated_question", which
            # diverges from the chain's declared output_keys whenever
            # return_source_documents is False or
            # return_generated_question is True. Build the output the same
            # way as the success path so validation never breaks.
            return self._build_output(message, docs, new_question)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        return self._build_output(answer, docs, new_question)
models/distance.py CHANGED
@@ -1,6 +1,13 @@
1
  import enum
2
 
3
 
 
 
 
 
 
 
 
4
  class DistanceStrategy(str, enum.Enum):
5
  """Enumerator of the Distance strategies."""
6
 
 
1
  import enum
2
 
3
 
4
# Maximum acceptable distance per strategy: vector_store.py reads this as
# `distance_strategy_limit[self._distance_strategy]` and filters rows whose
# embedding distance to the query exceeds the value (`... < {limit}` in the
# similarity-search SQL).
# NOTE(review): keys are assumed to match the DistanceStrategy enum's string
# values ("l2", "cosine", "inner") — confirm against the enum members, since
# a mismatch would raise KeyError at query time.
distance_strategy_limit = {
    "l2": 1.05,
    "cosine": 0.55,
    "inner": 1.0,
}
9
+
10
+
11
  class DistanceStrategy(str, enum.Enum):
12
  """Enumerator of the Distance strategies."""
13
 
requirements.txt CHANGED
@@ -1,11 +1,12 @@
1
  gpt4all==1.0.12
2
  langchain==0.0.313
3
  openai==0.28.1
 
 
 
 
4
  streamlit==1.27.2
5
  streamlit-chat==0.1.1
6
  SQLAlchemy==2.0.22
7
  sqlite-vss==0.1.2
8
- pandas==2.1.1
9
- pgvector==0.2.3
10
- psycopg2-binary==2.9.9
11
- psycopg2==2.9.9
 
1
  gpt4all==1.0.12
2
  langchain==0.0.313
3
  openai==0.28.1
4
+ pandas==2.1.1
5
+ pgvector==0.2.3
6
+ psycopg2-binary==2.9.9
7
+ psycopg2==2.9.9
8
  streamlit==1.27.2
9
  streamlit-chat==0.1.1
10
  SQLAlchemy==2.0.22
11
  sqlite-vss==0.1.2
12
+ tiktoken==0.5.1
 
 
 
vector_store.py CHANGED
@@ -14,10 +14,10 @@ from sqlalchemy import delete, text
14
  from sqlalchemy.orm import Session
15
 
16
  from model import Article
17
- from models.distance import DistanceStrategy
18
  from utils import str_to_list
19
 
20
- DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
21
 
22
  _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
23
 
@@ -252,6 +252,8 @@ class CustomVectorStore(VectorStore):
252
  k: int = 4,
253
  ) -> List[Any]:
254
  """Query the collection."""
 
 
255
  with Session(self._conn) as session:
256
  results = session.execute(
257
  text(
@@ -272,10 +274,11 @@ class CustomVectorStore(VectorStore):
272
  left join author on author.id = article_author.author_id
273
  where
274
  abstract_en != '' and
275
- abstract_en != 'None'
 
276
  GROUP BY a.id
277
  ORDER BY distance
278
- LIMIT {k};
279
  """
280
  )
281
  )
 
14
  from sqlalchemy.orm import Session
15
 
16
  from model import Article
17
+ from models.distance import DistanceStrategy, distance_strategy_limit
18
  from utils import str_to_list
19
 
20
+ DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
21
 
22
  _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
23
 
 
252
  k: int = 4,
253
  ) -> List[Any]:
254
  """Query the collection."""
255
+
256
+ limit = distance_strategy_limit[self._distance_strategy]
257
  with Session(self._conn) as session:
258
  results = session.execute(
259
  text(
 
274
  left join author on author.id = article_author.author_id
275
  where
276
  abstract_en != '' and
277
+ abstract_en != 'None' and
278
+ abstract_embedding_en {self.distance_strategy} '{str(embedding)}' < {limit}
279
  GROUP BY a.id
280
  ORDER BY distance
281
+ LIMIT 100;
282
  """
283
  )
284
  )