Spaces:
Sleeping
Sleeping
Merge branch 'feat/extract_keyword' into main
Browse files- sorbobotapp/app.py +1 -2
- sorbobotapp/chain.py +1 -2
- sorbobotapp/conversation_retrieval_chain.py +36 -0
- sorbobotapp/keyword_extraction.py +58 -0
- sorbobotapp/vector_store.py +3 -4
sorbobotapp/app.py
CHANGED
@@ -3,12 +3,11 @@ import os
|
|
3 |
|
4 |
import streamlit as st
|
5 |
import streamlit.components.v1 as components
|
6 |
-
from langchain.callbacks import get_openai_callback
|
7 |
-
|
8 |
from chain import get_chain
|
9 |
from chat_history import insert_chat_history, insert_chat_history_articles
|
10 |
from connection import connect
|
11 |
from css import load_css
|
|
|
12 |
from message import Message
|
13 |
|
14 |
st.set_page_config(layout="wide")
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
import streamlit.components.v1 as components
|
|
|
|
|
6 |
from chain import get_chain
|
7 |
from chat_history import insert_chat_history, insert_chat_history_articles
|
8 |
from connection import connect
|
9 |
from css import load_css
|
10 |
+
from langchain.callbacks import get_openai_callback
|
11 |
from message import Message
|
12 |
|
13 |
st.set_page_config(layout="wide")
|
sorbobotapp/chain.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
import os
|
2 |
|
3 |
import sqlalchemy
|
|
|
4 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
5 |
from langchain.embeddings import GPT4AllEmbeddings
|
6 |
from langchain.llms import OpenAI
|
7 |
-
|
8 |
-
from conversation_retrieval_chain import CustomConversationalRetrievalChain
|
9 |
from vector_store import CustomVectorStore
|
10 |
|
11 |
|
|
|
1 |
import os
|
2 |
|
3 |
import sqlalchemy
|
4 |
+
from conversation_retrieval_chain import CustomConversationalRetrievalChain
|
5 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
6 |
from langchain.embeddings import GPT4AllEmbeddings
|
7 |
from langchain.llms import OpenAI
|
|
|
|
|
8 |
from vector_store import CustomVectorStore
|
9 |
|
10 |
|
sorbobotapp/conversation_retrieval_chain.py
CHANGED
@@ -1,12 +1,17 @@
|
|
1 |
import inspect
|
|
|
2 |
from typing import Any, Dict, Optional
|
3 |
|
|
|
4 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
5 |
from langchain.chains.conversational_retrieval.base import (
|
6 |
ConversationalRetrievalChain, _get_chat_history)
|
|
|
7 |
|
8 |
|
9 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
|
|
|
10 |
def _handle_docs(self, docs):
|
11 |
if len(docs) == 0:
|
12 |
return False, "No documents found. Can you rephrase ?"
|
@@ -16,6 +21,33 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
16 |
return False, "Too many documents found. Can you specify your request ?"
|
17 |
return True, ""
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def _call(
|
20 |
self,
|
21 |
inputs: Dict[str, Any],
|
@@ -40,6 +72,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
40 |
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
41 |
else:
|
42 |
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
|
|
|
43 |
valid_docs, message = self._handle_docs(docs)
|
44 |
if not valid_docs:
|
45 |
return {
|
@@ -47,6 +80,9 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
47 |
"source_documents": docs,
|
48 |
}
|
49 |
|
|
|
|
|
|
|
50 |
new_inputs = inputs.copy()
|
51 |
if self.rephrase_question:
|
52 |
new_inputs["question"] = new_question
|
|
|
1 |
import inspect
|
2 |
+
import json
|
3 |
from typing import Any, Dict, Optional
|
4 |
|
5 |
+
from keyword_extraction import KeywordExtractor
|
6 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
7 |
from langchain.chains.conversational_retrieval.base import (
|
8 |
ConversationalRetrievalChain, _get_chat_history)
|
9 |
+
from langchain.schema import Document
|
10 |
|
11 |
|
12 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
13 |
+
keyword_extractor: KeywordExtractor = KeywordExtractor()
|
14 |
+
|
15 |
def _handle_docs(self, docs):
|
16 |
if len(docs) == 0:
|
17 |
return False, "No documents found. Can you rephrase ?"
|
|
|
21 |
return False, "Too many documents found. Can you specify your request ?"
|
22 |
return True, ""
|
23 |
|
24 |
+
def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
|
25 |
+
"""Rerank documents based on the number of similar keywords
|
26 |
+
|
27 |
+
Args:
|
28 |
+
question (str): Orinal question
|
29 |
+
docs (list[Document]): List of documents
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
list[Document]: List of documents sorted by the number of similar keywords
|
33 |
+
"""
|
34 |
+
keywords = self.keyword_extractor(question)
|
35 |
+
|
36 |
+
for doc in docs:
|
37 |
+
doc.metadata["similar_keyword"] = 0
|
38 |
+
doc_keywords = json.loads(doc.page_content)["keywords"]
|
39 |
+
if doc_keywords is None:
|
40 |
+
continue
|
41 |
+
doc_keywords = doc_keywords.lower().split(",")
|
42 |
+
|
43 |
+
for kw in keywords:
|
44 |
+
if kw.lower() in doc_keywords:
|
45 |
+
doc.metadata["similar_keyword"] += 1
|
46 |
+
print("similar keyword : ", kw)
|
47 |
+
|
48 |
+
docs = sorted(docs, key=lambda x: x.metadata["similar_keyword"])
|
49 |
+
return docs
|
50 |
+
|
51 |
def _call(
|
52 |
self,
|
53 |
inputs: Dict[str, Any],
|
|
|
72 |
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
73 |
else:
|
74 |
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
|
75 |
+
|
76 |
valid_docs, message = self._handle_docs(docs)
|
77 |
if not valid_docs:
|
78 |
return {
|
|
|
80 |
"source_documents": docs,
|
81 |
}
|
82 |
|
83 |
+
# Add reranking
|
84 |
+
docs = self.rerank_documents(new_question, docs)
|
85 |
+
|
86 |
new_inputs = inputs.copy()
|
87 |
if self.rephrase_question:
|
88 |
new_inputs["question"] = new_question
|
sorbobotapp/keyword_extraction.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
from langchain.chat_models import ChatOpenAI
|
4 |
+
from langchain.output_parsers import NumberedListOutputParser
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from utils import str_to_list
|
7 |
+
|
8 |
+
query_template = """
|
9 |
+
You are a bi-lingual (french and english) linguistic teacher working at a top-tier university.
|
10 |
+
We are conducting a research project that requires the extraction of keywords from chatbot queries.
|
11 |
+
Below, you will find a query. Please identify and rank the three most important keywords or phrases (n-grams) based on their relevance to the main topic of the query.
|
12 |
+
For each keyword or phrase, assign it to one of the following categories: ["University / Company", "Research domain", "Country", "Name", "Other"].
|
13 |
+
An 'n-gram' refers to a contiguous sequence of words, where 'n' can be 1 for a single word, 2 for a pair of words, and so on, up to two words in length.
|
14 |
+
Please ensure not to list more than three n-grams in total.
|
15 |
+
Your expertise in linguistic analysis is crucial for the success of this project. Thank you for your contribution.
|
16 |
+
|
17 |
+
Please attach your ranked list in the following format:
|
18 |
+
1. Keyword/Phrase - Category
|
19 |
+
2. Keyword/Phrase - Category
|
20 |
+
3. Keyword/Phrase - Category
|
21 |
+
|
22 |
+
You must be concise and don't need to justify your choices.
|
23 |
+
```
|
24 |
+
{query}
|
25 |
+
```
|
26 |
+
"""
|
27 |
+
|
28 |
+
output_parser = NumberedListOutputParser()
|
29 |
+
format_instructions = output_parser.get_format_instructions()
|
30 |
+
|
31 |
+
|
32 |
+
class KeywordExtractor:
|
33 |
+
def __init__(self):
|
34 |
+
super().__init__()
|
35 |
+
self.model = ChatOpenAI()
|
36 |
+
self.prompt = ChatPromptTemplate.from_template(
|
37 |
+
template=query_template,
|
38 |
+
)
|
39 |
+
|
40 |
+
self.chain = self.prompt | self.model # | output_parser
|
41 |
+
|
42 |
+
def __call__(
|
43 |
+
self, inputs: str, filter_categories: list[str] = ["Research domain"]
|
44 |
+
) -> Any:
|
45 |
+
output = self.chain.invoke({"query": inputs})
|
46 |
+
|
47 |
+
keywords = output_parser.parse(output.content)
|
48 |
+
|
49 |
+
filtered_keywords = []
|
50 |
+
for keyword in keywords:
|
51 |
+
if " - " not in keyword:
|
52 |
+
continue
|
53 |
+
|
54 |
+
keyword, category = keyword.split(" - ", maxsplit=2)
|
55 |
+
if category in filter_categories:
|
56 |
+
filtered_keywords.append(keyword)
|
57 |
+
|
58 |
+
return filtered_keywords
|
sorbobotapp/vector_store.py
CHANGED
@@ -10,11 +10,10 @@ import sqlalchemy
|
|
10 |
from langchain.docstore.document import Document
|
11 |
from langchain.schema.embeddings import Embeddings
|
12 |
from langchain.vectorstores.base import VectorStore
|
13 |
-
from sqlalchemy import delete, text
|
14 |
-
from sqlalchemy.orm import Session
|
15 |
-
|
16 |
from models.article import Article
|
17 |
from models.distance import DistanceStrategy, distance_strategy_limit
|
|
|
|
|
18 |
from utils import str_to_list
|
19 |
|
20 |
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
@@ -245,7 +244,7 @@ class CustomVectorStore(VectorStore):
|
|
245 |
"doi": result["doi"],
|
246 |
"hal_id": result["hal_id"],
|
247 |
"distance": result["distance"],
|
248 |
-
"abstract": result["abstract"],
|
249 |
},
|
250 |
),
|
251 |
result["distance"] if self.embedding_function is not None else None,
|
|
|
10 |
from langchain.docstore.document import Document
|
11 |
from langchain.schema.embeddings import Embeddings
|
12 |
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
|
13 |
from models.article import Article
|
14 |
from models.distance import DistanceStrategy, distance_strategy_limit
|
15 |
+
from sqlalchemy import delete, text
|
16 |
+
from sqlalchemy.orm import Session
|
17 |
from utils import str_to_list
|
18 |
|
19 |
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
|
|
244 |
"doi": result["doi"],
|
245 |
"hal_id": result["hal_id"],
|
246 |
"distance": result["distance"],
|
247 |
+
"abstract": result["abstract"][0],
|
248 |
},
|
249 |
),
|
250 |
result["distance"] if self.embedding_function is not None else None,
|