leo-bourrel commited on
Commit
c6332f2
·
2 Parent(s): 28a498b dc294ab

Merge branch 'feat/extract_keyword' into main

Browse files
sorbobotapp/app.py CHANGED
@@ -3,12 +3,11 @@ import os
3
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
6
- from langchain.callbacks import get_openai_callback
7
-
8
  from chain import get_chain
9
  from chat_history import insert_chat_history, insert_chat_history_articles
10
  from connection import connect
11
  from css import load_css
 
12
  from message import Message
13
 
14
  st.set_page_config(layout="wide")
 
3
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
 
 
6
  from chain import get_chain
7
  from chat_history import insert_chat_history, insert_chat_history_articles
8
  from connection import connect
9
  from css import load_css
10
+ from langchain.callbacks import get_openai_callback
11
  from message import Message
12
 
13
  st.set_page_config(layout="wide")
sorbobotapp/chain.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
 
3
  import sqlalchemy
 
4
  from langchain.chains.conversation.memory import ConversationBufferMemory
5
  from langchain.embeddings import GPT4AllEmbeddings
6
  from langchain.llms import OpenAI
7
-
8
- from conversation_retrieval_chain import CustomConversationalRetrievalChain
9
  from vector_store import CustomVectorStore
10
 
11
 
 
1
  import os
2
 
3
  import sqlalchemy
4
+ from conversation_retrieval_chain import CustomConversationalRetrievalChain
5
  from langchain.chains.conversation.memory import ConversationBufferMemory
6
  from langchain.embeddings import GPT4AllEmbeddings
7
  from langchain.llms import OpenAI
 
 
8
  from vector_store import CustomVectorStore
9
 
10
 
sorbobotapp/conversation_retrieval_chain.py CHANGED
@@ -1,12 +1,17 @@
1
  import inspect
 
2
  from typing import Any, Dict, Optional
3
 
 
4
  from langchain.callbacks.manager import CallbackManagerForChainRun
5
  from langchain.chains.conversational_retrieval.base import (
6
  ConversationalRetrievalChain, _get_chat_history)
 
7
 
8
 
9
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
10
  def _handle_docs(self, docs):
11
  if len(docs) == 0:
12
  return False, "No documents found. Can you rephrase ?"
@@ -16,6 +21,33 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
16
  return False, "Too many documents found. Can you specify your request ?"
17
  return True, ""
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def _call(
20
  self,
21
  inputs: Dict[str, Any],
@@ -40,6 +72,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
40
  docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
41
  else:
42
  docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
 
43
  valid_docs, message = self._handle_docs(docs)
44
  if not valid_docs:
45
  return {
@@ -47,6 +80,9 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
47
  "source_documents": docs,
48
  }
49
 
 
 
 
50
  new_inputs = inputs.copy()
51
  if self.rephrase_question:
52
  new_inputs["question"] = new_question
 
1
  import inspect
2
+ import json
3
  from typing import Any, Dict, Optional
4
 
5
+ from keyword_extraction import KeywordExtractor
6
  from langchain.callbacks.manager import CallbackManagerForChainRun
7
  from langchain.chains.conversational_retrieval.base import (
8
  ConversationalRetrievalChain, _get_chat_history)
9
+ from langchain.schema import Document
10
 
11
 
12
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
13
+ keyword_extractor: KeywordExtractor = KeywordExtractor()
14
+
15
  def _handle_docs(self, docs):
16
  if len(docs) == 0:
17
  return False, "No documents found. Can you rephrase ?"
 
21
  return False, "Too many documents found. Can you specify your request ?"
22
  return True, ""
23
 
24
+ def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
25
+ """Rerank documents based on the number of similar keywords
26
+
27
+ Args:
28
+ question (str): Orinal question
29
+ docs (list[Document]): List of documents
30
+
31
+ Returns:
32
+ list[Document]: List of documents sorted by the number of similar keywords
33
+ """
34
+ keywords = self.keyword_extractor(question)
35
+
36
+ for doc in docs:
37
+ doc.metadata["similar_keyword"] = 0
38
+ doc_keywords = json.loads(doc.page_content)["keywords"]
39
+ if doc_keywords is None:
40
+ continue
41
+ doc_keywords = doc_keywords.lower().split(",")
42
+
43
+ for kw in keywords:
44
+ if kw.lower() in doc_keywords:
45
+ doc.metadata["similar_keyword"] += 1
46
+ print("similar keyword : ", kw)
47
+
48
+ docs = sorted(docs, key=lambda x: x.metadata["similar_keyword"])
49
+ return docs
50
+
51
  def _call(
52
  self,
53
  inputs: Dict[str, Any],
 
72
  docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
73
  else:
74
  docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
75
+
76
  valid_docs, message = self._handle_docs(docs)
77
  if not valid_docs:
78
  return {
 
80
  "source_documents": docs,
81
  }
82
 
83
+ # Add reranking
84
+ docs = self.rerank_documents(new_question, docs)
85
+
86
  new_inputs = inputs.copy()
87
  if self.rephrase_question:
88
  new_inputs["question"] = new_question
sorbobotapp/keyword_extraction.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.output_parsers import NumberedListOutputParser
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from utils import str_to_list
7
+
8
+ query_template = """
9
+ You are a bi-lingual (french and english) linguistic teacher working at a top-tier university.
10
+ We are conducting a research project that requires the extraction of keywords from chatbot queries.
11
+ Below, you will find a query. Please identify and rank the three most important keywords or phrases (n-grams) based on their relevance to the main topic of the query.
12
+ For each keyword or phrase, assign it to one of the following categories: ["University / Company", "Research domain", "Country", "Name", "Other"].
13
+ An 'n-gram' refers to a contiguous sequence of words, where 'n' can be 1 for a single word, 2 for a pair of words, and so on, up to two words in length.
14
+ Please ensure not to list more than three n-grams in total.
15
+ Your expertise in linguistic analysis is crucial for the success of this project. Thank you for your contribution.
16
+
17
+ Please attach your ranked list in the following format:
18
+ 1. Keyword/Phrase - Category
19
+ 2. Keyword/Phrase - Category
20
+ 3. Keyword/Phrase - Category
21
+
22
+ You must be concise and don't need to justify your choices.
23
+ ```
24
+ {query}
25
+ ```
26
+ """
27
+
28
+ output_parser = NumberedListOutputParser()
29
+ format_instructions = output_parser.get_format_instructions()
30
+
31
+
32
+ class KeywordExtractor:
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.model = ChatOpenAI()
36
+ self.prompt = ChatPromptTemplate.from_template(
37
+ template=query_template,
38
+ )
39
+
40
+ self.chain = self.prompt | self.model # | output_parser
41
+
42
+ def __call__(
43
+ self, inputs: str, filter_categories: list[str] = ["Research domain"]
44
+ ) -> Any:
45
+ output = self.chain.invoke({"query": inputs})
46
+
47
+ keywords = output_parser.parse(output.content)
48
+
49
+ filtered_keywords = []
50
+ for keyword in keywords:
51
+ if " - " not in keyword:
52
+ continue
53
+
54
+ keyword, category = keyword.split(" - ", maxsplit=2)
55
+ if category in filter_categories:
56
+ filtered_keywords.append(keyword)
57
+
58
+ return filtered_keywords
sorbobotapp/vector_store.py CHANGED
@@ -10,11 +10,10 @@ import sqlalchemy
10
  from langchain.docstore.document import Document
11
  from langchain.schema.embeddings import Embeddings
12
  from langchain.vectorstores.base import VectorStore
13
- from sqlalchemy import delete, text
14
- from sqlalchemy.orm import Session
15
-
16
  from models.article import Article
17
  from models.distance import DistanceStrategy, distance_strategy_limit
 
 
18
  from utils import str_to_list
19
 
20
  DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
@@ -245,7 +244,7 @@ class CustomVectorStore(VectorStore):
245
  "doi": result["doi"],
246
  "hal_id": result["hal_id"],
247
  "distance": result["distance"],
248
- "abstract": result["abstract"],
249
  },
250
  ),
251
  result["distance"] if self.embedding_function is not None else None,
 
10
  from langchain.docstore.document import Document
11
  from langchain.schema.embeddings import Embeddings
12
  from langchain.vectorstores.base import VectorStore
 
 
 
13
  from models.article import Article
14
  from models.distance import DistanceStrategy, distance_strategy_limit
15
+ from sqlalchemy import delete, text
16
+ from sqlalchemy.orm import Session
17
  from utils import str_to_list
18
 
19
  DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
 
244
  "doi": result["doi"],
245
  "hal_id": result["hal_id"],
246
  "distance": result["distance"],
247
+ "abstract": result["abstract"][0],
248
  },
249
  ),
250
  result["distance"] if self.embedding_function is not None else None,