Commit 9a8353d · Parent(s): e8e78ae
Saiteja Solleti committed

Milvus insert and search addition

Files changed:
- app.py (+25, -2)
- createmilvusschema.py (+0, -1)
- crudmilvus.py (+0, -13)
- insertmilvushelper.py (+143, -0)
- loaddataset.py (+1, -1)
- requirements.txt (+2, -1)
- searchmilvushelper.py (+52, -0)
app.py CHANGED

@@ -3,16 +3,28 @@ import os
 
 from loaddataset import ExtractRagBenchData
 from createmilvusschema import CreateMilvusDbSchema
+from insertmilvushelper import EmbedAllDocumentsAndInsert
+from sentence_transformers import SentenceTransformer
+from searchmilvushelper import SearchTopKDocuments
+
 from model import generate_response
 from huggingface_hub import login
 from huggingface_hub import whoami
 from huggingface_hub import dataset_info
 
 
+# Load embedding model
+QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
+WINDOW_SIZE = 5
+OVERLAP = 2
+RETRIVE_TOP_K_SIZE = 10
+
+
 hf_token = os.getenv("HF_TOKEN")
 login(hf_token)
 
 rag_extracted_data = ExtractRagBenchData()
+print(rag_extracted_data.head(5))
 
 #invoke create milvus db function
 try:
@@ -20,12 +32,23 @@ try:
 except Exception as e:
     print(f"Error creating Milvus DB schema: {e}")
 
-
+#insert embeddings into milvus db
+"""
+EmbedAllDocumentsAndInsert(QUERY_EMBEDDING_MODEL, rag_extracted_data, db_collection, window_size=WINDOW_SIZE, overlap=OVERLAP)
+"""
+query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
+
+results_for_top5_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
+print(results_for_top5_chunks)
+
 
 def chatbot(prompt):
     return whoami()
 
-iface = gr.Interface(fn=chatbot,
+iface = gr.Interface(fn=chatbot,
+                     inputs="text",
+                     outputs="text",
+                     title="Capstone Project Group 10")
 
 if __name__ == "__main__":
     iface.launch()
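Note: the bulk ingest call is committed commented out, so the Space does not re-embed the whole RAGBench corpus on every restart. A minimal sketch of the intended end-to-end flow, assuming db_collection is the pymilvus Collection returned by CreateMilvusDbSchema() inside the try block (that assignment sits outside the hunks shown here):

from sentence_transformers import SentenceTransformer
from createmilvusschema import CreateMilvusDbSchema
from insertmilvushelper import EmbedAllDocumentsAndInsert
from searchmilvushelper import SearchTopKDocuments
from loaddataset import ExtractRagBenchData

model = SentenceTransformer('all-MiniLM-L6-v2')
db_collection = CreateMilvusDbSchema()      # assumed to return a pymilvus Collection
rag_extracted_data = ExtractRagBenchData()

# One-time ingest: chunk every document, embed each chunk, insert into Milvus.
EmbedAllDocumentsAndInsert(model, rag_extracted_data, db_collection,
                           window_size=5, overlap=2)

# Afterwards each query only pays for one query embedding plus a vector search.
top_chunks = SearchTopKDocuments(db_collection, "example query", model, top_k=10)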
createmilvusschema.py CHANGED

@@ -5,7 +5,6 @@ milvus_token = os.getenv("MILVUS_TOKEN")
 
 COLLECTION_NAME = "final_ragbench_document_embeddings"
 MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
-connections.connect("default", uri=MILVUS_CLOUD_URI, token=milvus_token)
 
 #Function to create milvus db schema to insert the data
 def CreateMilvusDbSchema():
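Dropping the module-level connections.connect means importing this file no longer opens a network connection as a side effect. A minimal sketch of connecting inside the function instead; the actual schema body is outside this hunk, so the shape of CreateMilvusDbSchema below is an assumption:

import os
from pymilvus import connections, Collection

COLLECTION_NAME = "final_ragbench_document_embeddings"
MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
milvus_token = os.getenv("MILVUS_TOKEN")

def CreateMilvusDbSchema():
    # Hypothetical shape: connect lazily, on first call rather than at import time.
    if not connections.has_connection("default"):
        connections.connect("default", uri=MILVUS_CLOUD_URI, token=milvus_token)
    return Collection(COLLECTION_NAME)  # assumes the collection already exists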
crudmilvus.py DELETED

@@ -1,13 +0,0 @@
-import os
-from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
-from sentence_transformers import SentenceTransformer
-
-milvus_token = os.getenv("MILVUS_TOKEN")
-
-COLLECTION_NAME = "final_ragbench_document_embeddings"
-MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
-connections.connect("default", uri=MILVUS_CLOUD_URI, token=milvus_token)
-
-# Verify connection
-print(connections.get_connection_addr("default"))
-
insertmilvushelper.py ADDED

@@ -0,0 +1,143 @@
+import nltk
+import pandas as pd
+import numpy as np
+nltk.data.path.append("/content/nltk_data")
+nltk.download('punkt')
+nltk.download('wordnet')
+nltk.download('omw-1.4')
+nltk.download('punkt_tab')
+
+from nltk.tokenize import sent_tokenize
+
+
+# Splits a list of sentences into overlapping chunks using a sliding-window approach.
+# Args:
+#   sentences (list): List of sentences to split into chunks.
+#   window_size (int): Number of sentences in each chunk. Default is 6.
+#   overlap (int): Number of overlapping sentences between consecutive chunks. Default is 3.
+# Returns:
+#   list: List of text chunks, where each chunk is a string of concatenated sentences.
+
+def split_into_sliding_windows(sentences, window_size=6, overlap=3):
+
+    # Validate input parameters
+    if window_size <= overlap:
+        raise ValueError("window_size must be greater than overlap.")
+    if not sentences:
+        return []
+
+    chunks = []
+    step = window_size - overlap  # How much to move the window each time
+
+    # Iterate over the sentences with the specified step size
+    for i in range(0, len(sentences), step):
+        chunk = sentences[i:i + window_size]
+        if len(chunk) >= overlap:  # Ensure chunks have the minimum required overlap
+            chunks.append(" ".join(chunk))  # Join sentences into a text block
+
+    return chunks
+
+# Processes documents using a sliding-window approach and inserts sentence chunks into Milvus.
+# Args:
+#   model: The embedding model used to generate document embeddings.
+#   extracted_data: Pandas DataFrame containing the extracted data.
+#   collectionInstance: Milvus collection instance to insert data into.
+#   window_size: Number of sentences in each chunk.
+#   overlap: Number of overlapping sentences between consecutive chunks.
+
+def EmbedAllDocumentsAndInsert(model, extracted_data, collectionInstance, window_size=5, overlap=2):
+
+    count = 0
+    total_docs = len(extracted_data)
+    print(f"Total documents: {total_docs}")
+
+    for index, row in extracted_data.iterrows():
+        document = row["documents"]                      # Document text
+        doc_id = row["id"]                               # Document ID
+        doccontextrel = row["gpt3_context_relevance"]    # Context relevance score
+        doccontextutil = row["gpt35_utilization"]        # Context utilization score
+        docadherence = row["gpt3_adherence"]             # Adherence score
+        datasetname = row["dataset_name"]                # Dataset name
+        relevance_score = row["relevance_score"]         # Relevance score
+        utilization_score = row["utilization_score"]     # Utilization score
+        completeness_score = row["completeness_score"]   # Completeness score
+
+        if isinstance(document, list):
+            # Flatten the list into a single string
+            document = " ".join([str(item) for item in document if isinstance(item, str)])
+        elif not isinstance(document, str):
+            # If the document is not a string or list, convert it to a string
+            document = str(document)
+
+        # Step 1: Tokenize document into sentences
+        sentences = sent_tokenize(document) if isinstance(document, str) else document
+
+        # Step 2: Generate overlapping chunks
+        chunks = split_into_sliding_windows(sentences, window_size, overlap)
+
+        print(f"Total chunks for document {index}: {len(chunks)}")
+
+        for chunk_index, chunk_text in enumerate(chunks):
+            # Step 3: Generate embedding for each chunk
+            chunk_vector = np.array(model.encode(chunk_text), dtype=np.float32).flatten().tolist()
+
+            print(f"chunk_index= {chunk_index}")
+
+            # Step 4: Insert chunk into Milvus as separate columns
+            insert_embeddings_into_milvus(
+                collectionInstance,
+                chunk_vector,
+                f"{chunk_index}__{doc_id}",  # Unique ID for the chunk
+                doc_id,                      # Unique ID for the doc
+                index,
+                float(doccontextrel) if pd.notna(doccontextrel) else 0.0,    # Handle NaN values
+                float(doccontextutil) if pd.notna(doccontextutil) else 0.0,  # Handle NaN values
+                float(docadherence) if pd.notna(docadherence) else 0.0,      # Handle NaN values
+                datasetname,                                                 # Dataset name column
+                float(relevance_score) if pd.notna(relevance_score) else 0.0,       # Handle NaN values
+                float(utilization_score) if pd.notna(utilization_score) else 0.0,   # Handle NaN values
+                float(completeness_score) if pd.notna(completeness_score) else 0.0  # Handle NaN values
+            )
+
+            count += 1
+            if count % 1000 == 0:
+                print(f"Uploaded {count} chunks to Milvus.")
+
+# Inserts document embeddings into Milvus along with metadata.
+# Args:
+#   collection: Milvus collection instance.
+#   embeddings: Embedding vector for the chunk.
+#   chunk_doc_id: Unique ID for the chunk.
+#   doc_id: Unique ID for the document.
+#   index: Index of the document in the dataset.
+#   doccontextrel: Context relevance score.
+#   doccontextutil: Context utilization score.
+#   docadherence: Adherence score.
+#   datasetname: Name of the dataset.
+#   relevance_score, utilization_score, completeness_score: Overall quality scores.
+
+def insert_embeddings_into_milvus(collection, embeddings, chunk_doc_id, doc_id, index,
+                                  doccontextrel, doccontextutil, docadherence, datasetname,
+                                  relevance_score, utilization_score, completeness_score):
+
+    try:
+        print(f"Inserting chunk {chunk_doc_id} doc {doc_id} (index {index})")
+        insert_data = [
+            [str(chunk_doc_id)],         # Primary key field (document_id)
+            [str(doc_id)],               # Document ID field
+            [embeddings],                # Vector field (embedding)
+            [float(doccontextrel)],      # Context relevance score field
+            [float(doccontextutil)],     # Context utilization score field
+            [float(docadherence)],       # Adherence score field
+            [str(datasetname)],          # Dataset name field
+            [float(relevance_score)],    # Relevance score field
+            [float(utilization_score)],  # Utilization score field
+            [float(completeness_score)]  # Completeness score field
+        ]
+        collection.insert(insert_data)
+    except Exception as e:
+        print(f"Error inserting chunk {chunk_doc_id} doc {doc_id} (index {index}): {e}")
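A quick worked example of split_into_sliding_windows: with window_size=5 and overlap=2 the step is 5 - 2 = 3, so consecutive chunks share their last two sentences. The sentence labels below are made up for illustration:

sentences = [f"s{i}." for i in range(8)]  # "s0." ... "s7."
chunks = split_into_sliding_windows(sentences, window_size=5, overlap=2)

# Windows start at indices 0, 3, 6:
#   chunks[0] == "s0. s1. s2. s3. s4."
#   chunks[1] == "s3. s4. s5. s6. s7."
#   chunks[2] == "s6. s7."   (length 2 >= overlap, so the tail chunk is kept)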
loaddataset.py CHANGED

@@ -1,7 +1,7 @@
 import pandas as pd
 from datasets import load_dataset
 from logger import logger
-from typing import Dict, List
+from typing import Dict, List
 
 
 DATASET_CONFIGS = [
requirements.txt CHANGED

@@ -2,4 +2,5 @@ gradio
 transformers
 torch
 huggingface_hub
-pymilvus
+pymilvus
+nltk
searchmilvushelper.py ADDED

@@ -0,0 +1,52 @@
+# Search Milvus by generating an embedding for the query text; returns the top_k most similar documents.
+# Retrieves all columns defined in the Milvus schema.
+
+def SearchTopKDocuments(collection, query_text, model, top_k=10):
+
+    # Generate embedding for the query text
+    query_embedding = model.encode(query_text, convert_to_numpy=True)
+
+    # Define search parameters
+    search_params = {
+        "metric_type": "COSINE",  # Similarity metric
+        "params": {"ef": 64}      # Controls recall; higher values = better accuracy but slower
+    }
+
+    # Perform the search
+    results = collection.search(
+        data=[query_embedding],
+        anns_field="chunk_embedding",  # Field containing the embeddings
+        param=search_params,
+        limit=top_k,
+        output_fields=[
+            "chunk_doc_id",         # Primary key
+            "doc_id",               # Document ID
+            "context_relevance",    # Context relevance score
+            "context_utilization",  # Context utilization score
+            "adherence",            # Adherence score
+            "dataset_name",         # Dataset name
+            "relevance_score",      # Relevance score
+            "utilization_score",    # Utilization score
+            "completeness_score"    # Completeness score
+        ]
+    )
+
+    # Process and return the results
+    top_documents = []
+    for hits in results:
+        for hit in hits:
+            doc = {
+                "chunk_doc_id": hit.entity.get("chunk_doc_id"),
+                "doc_id": hit.entity.get("doc_id"),
+                "context_relevance": hit.entity.get("context_relevance"),
+                "context_utilization": hit.entity.get("context_utilization"),
+                "adherence": hit.entity.get("adherence"),
+                "dataset_name": hit.entity.get("dataset_name"),
+                "relevance_score": hit.entity.get("relevance_score"),
+                "utilization_score": hit.entity.get("utilization_score"),
+                "completeness_score": hit.entity.get("completeness_score"),
+                "distance": hit.distance  # Similarity score (cosine)
+            }
+            top_documents.append(doc)
+
+    return top_documents
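A short usage sketch, assuming the ingest above has already run; pymilvus requires load() on a collection before it can be searched:

from sentence_transformers import SentenceTransformer
from createmilvusschema import CreateMilvusDbSchema
from searchmilvushelper import SearchTopKDocuments

model = SentenceTransformer('all-MiniLM-L6-v2')
collection = CreateMilvusDbSchema()  # assumed to return the pymilvus Collection
collection.load()                    # searching an unloaded collection raises an error

docs = SearchTopKDocuments(collection, "what was the 2015 net revenue?", model, top_k=10)
for doc in docs:
    # With metric_type COSINE, hit.distance holds the similarity, so larger values generally mean closer matches.
    print(doc["chunk_doc_id"], doc["dataset_name"], round(doc["distance"], 4))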