Saiteja Solleti committed on
Commit
9a8353d
·
1 Parent(s): e8e78ae

Milvus insert and search addition

Browse files
app.py CHANGED
@@ -3,16 +3,28 @@ import os
3
 
4
  from loaddataset import ExtractRagBenchData
5
  from createmilvusschema import CreateMilvusDbSchema
 
 
 
 
6
  from model import generate_response
7
  from huggingface_hub import login
8
  from huggingface_hub import whoami
9
  from huggingface_hub import dataset_info
10
 
11
 
 
 
 
 
 
 
 
12
  hf_token = os.getenv("HF_TOKEN")
13
  login(hf_token)
14
 
15
  rag_extracted_data = ExtractRagBenchData()
 
16
 
17
  #invoke create milvus db function
18
  try:
@@ -20,12 +32,23 @@ try:
20
  except Exception as e:
21
  print(f"Error creating Milvus DB schema: {e}")
22
 
23
- print(rag_extracted_data.head(5))
 
 
 
 
 
 
 
 
24
 
25
  def chatbot(prompt):
26
  return whoami()
27
 
28
- iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="Capstone Project Group 10")
 
 
 
29
 
30
  if __name__ == "__main__":
31
  iface.launch()
 
3
 
4
  from loaddataset import ExtractRagBenchData
5
  from createmilvusschema import CreateMilvusDbSchema
6
+ from crudmilvushelper import EmbedAllDocumentsAndInsert
7
+ from sentence_transformers import SentenceTransformer
8
+ from searchmilvushelper import SearchTopKDocuments
9
+
10
  from model import generate_response
11
  from huggingface_hub import login
12
  from huggingface_hub import whoami
13
  from huggingface_hub import dataset_info
14
 
15
 
16
+ # Load embedding model
17
+ QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
18
+ WINDOW_SIZE = 5
19
+ OVERLAP = 2
20
+ RETRIVE_TOP_K_SIZE=10
21
+
22
+
23
  hf_token = os.getenv("HF_TOKEN")
24
  login(hf_token)
25
 
26
  rag_extracted_data = ExtractRagBenchData()
27
+ print(rag_extracted_data.head(5))
28
 
29
  #invoke create milvus db function
30
  try:
 
32
  except Exception as e:
33
  print(f"Error creating Milvus DB schema: {e}")
34
 
35
+ #insert embeddings into milvus db
36
+ """
37
+ EmbedAllDocumentsAndInsert(QUERY_EMBEDDING_MODEL, rag_extracted_data, db_collection, window_size=WINDOW_SIZE, overlap=OVERLAP)
38
+ """
39
+ query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
40
+
41
+ results_for_top5_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
42
+ print(results_for_top5_chunks)
43
+
44
 
45
  def chatbot(prompt):
46
  return whoami()
47
 
48
+ iface = gr.Interface(fn=chatbot,
49
+ inputs="text",
50
+ outputs="text",
51
+ title="Capstone Project Group 10")
52
 
53
  if __name__ == "__main__":
54
  iface.launch()
createmilvusschema.py CHANGED
@@ -5,7 +5,6 @@ milvus_token = os.getenv("MILVUS_TOKEN")
5
 
6
  COLLECTION_NAME = "final_ragbench_document_embeddings"
7
  MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
8
- connections.connect("default", uri=MILVUS_CLOUD_URI, token=milvus_token)
9
 
10
  #Function to create milvus db schema to insert the data
11
  def CreateMilvusDbSchema():
 
5
 
6
  COLLECTION_NAME = "final_ragbench_document_embeddings"
7
  MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
 
8
 
9
  #Function to create milvus db schema to insert the data
10
  def CreateMilvusDbSchema():
crudmilvus.py DELETED
@@ -1,13 +0,0 @@
1
- import os
2
- from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
3
- from sentence_transformers import SentenceTransformer
4
-
5
- milvus_token = os.getenv("MILVUS_TOKEN")
6
-
7
- COLLECTION_NAME = "final_ragbench_document_embeddings"
8
- MILVUS_CLOUD_URI = "https://in03-7b4da1b7b588a88.serverless.gcp-us-west1.cloud.zilliz.com"
9
- connections.connect("default", uri=MILVUS_CLOUD_URI, token=milvus_token)
10
-
11
- # Verify connection
12
- print(connections.get_connection_addr("default"))
13
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
insertmilvushelper.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import pandas as pd
3
+ import numpy as np
# Add an extra directory to NLTK's resource search path, then fetch the
# tokenizer/corpus resources at import time.
nltk.data.path.append("/content/nltk_data")
for _resource in ('punkt', 'wordnet', 'omw-1.4', 'punkt_tab'):
    nltk.download(_resource)
9
+
10
+ from nltk.tokenize import sent_tokenize
11
+
12
+
13
+
14
+
15
+
def split_into_sliding_windows(sentences, window_size=6, overlap=3):
    """Split a list of sentences into overlapping chunks using a sliding window.

    Args:
        sentences (list): Sentences to group into chunks.
        window_size (int): Number of sentences in each chunk. Default is 6.
        overlap (int): Number of sentences shared by consecutive chunks.
            Default is 3.

    Returns:
        list: Text chunks, each a string of space-joined sentences. Empty
        when ``sentences`` is empty.

    Raises:
        ValueError: If ``window_size`` is not greater than ``overlap``.
    """
    # Validate input parameters
    if window_size <= overlap:
        raise ValueError("window_size must be greater than overlap.")
    if not sentences:
        return []

    step = window_size - overlap  # How far the window advances each time
    chunks = []

    for start in range(0, len(sentences), step):
        window = sentences[start:start + window_size]
        # Always keep the first window: otherwise an input with fewer than
        # `overlap` sentences would produce no chunk at all and the document
        # would silently never be embedded. Later (trailing) windows are
        # still dropped when they hold fewer than `overlap` sentences.
        if start == 0 or len(window) >= overlap:
            chunks.append(" ".join(window))  # Join sentences into a text block

    return chunks
41
+
def _nan_to_zero(value):
    """Return ``value`` as a float, mapping missing values (per ``pd.notna``) to 0.0."""
    return float(value) if pd.notna(value) else 0.0


def EmbedAllDocumentsAndInsert(model, extracted_data, collectionInstance, window_size=5, overlap=2):
    """Chunk every document with a sliding window, embed each chunk, and insert it into Milvus.

    Args:
        model: Embedding model; its ``encode(text)`` result is flattened into
            a list of float32 values for each chunk.
        extracted_data: Pandas DataFrame with the columns read below
            ("documents", "id", "gpt3_context_relevance", "gpt35_utilization",
            "gpt3_adherence", "dataset_name", "relevance_score",
            "utilization_score", "completeness_score").
        collectionInstance: Milvus collection instance passed through to
            ``insert_embeddings_into_milvus``.
        window_size: Number of sentences in each chunk.
        overlap: Number of overlapping sentences between consecutive chunks.

    Returns:
        int: Number of chunks for which an insert was attempted.
    """
    count = 0
    total_docs = len(extracted_data)
    print(f"Total documents: {total_docs}")

    for index, row in extracted_data.iterrows():
        document = row["documents"]      # Document text (string or list of strings)
        doc_id = row["id"]               # Document ID
        datasetname = row["dataset_name"]  # Dataset name

        # The scores are per-document, so convert them (NaN -> 0.0) once here
        # instead of once per chunk inside the inner loop.
        doccontextrel = _nan_to_zero(row["gpt3_context_relevance"])
        doccontextutil = _nan_to_zero(row["gpt35_utilization"])
        docadherence = _nan_to_zero(row["gpt3_adherence"])
        relevance_score = _nan_to_zero(row["relevance_score"])
        utilization_score = _nan_to_zero(row["utilization_score"])
        completeness_score = _nan_to_zero(row["completeness_score"])

        if isinstance(document, list):
            # Flatten the list into a single string; non-string items are skipped
            document = " ".join(item for item in document if isinstance(item, str))
        elif not isinstance(document, str):
            # Neither a string nor a list: fall back to its string form
            document = str(document)

        # Step 1: Tokenize document into sentences (document is always a str here)
        sentences = sent_tokenize(document)

        # Step 2: Generate overlapping chunks
        chunks = split_into_sliding_windows(sentences, window_size, overlap)

        print(f"Total chunks for document {index}: {len(chunks)}")

        for chunk_index, chunk_text in enumerate(chunks):
            # Step 3: Generate embedding for each chunk
            chunk_vector = np.array(model.encode(chunk_text), dtype=np.float32).flatten().tolist()

            print(f"chunk_index= {chunk_index}")

            # Step 4: Insert chunk into Milvus as separate columns
            insert_embeddings_into_milvus(
                collectionInstance,
                chunk_vector,
                f"{chunk_index}__{doc_id}",  # Unique ID for chunk
                doc_id,                      # Unique ID for doc
                index,
                doccontextrel,
                doccontextutil,
                docadherence,
                datasetname,
                relevance_score,
                utilization_score,
                completeness_score,
            )

            count += 1
            if count % 1000 == 0:
                print(f"Uploaded {count} chunks to Milvus.")

    return count
108
+
def insert_embeddings_into_milvus(collection, embeddings, chunk_doc_id, doc_id, index,
                                  doccontextrel, doccontextutil, docadherence, datasetname,
                                  relevance_score, utilization_score, completeness_score):
    """Insert one chunk embedding and its metadata into a Milvus collection.

    The insert is best-effort: any exception is printed rather than raised,
    and the outcome is reported through the return value so callers can tell
    failed inserts from successful ones.

    Args:
        collection: Milvus collection instance; only ``insert`` is called.
        embeddings: Embedding vector for the chunk.
        chunk_doc_id: Unique ID for the chunk.
        doc_id: Unique ID for the document.
        index: Index of the document in the dataset (used for logging only).
        doccontextrel: Context relevance score.
        doccontextutil: Context utilization score.
        docadherence: Adherence score.
        datasetname: Name of the dataset.
        relevance_score: Relevance score.
        utilization_score: Utilization score.
        completeness_score: Completeness score.

    Returns:
        bool: True if ``collection.insert`` completed without raising,
        False otherwise.
    """
    try:
        print(f"Inserting chunk {chunk_doc_id} doc {doc_id} (index {index})")
        # Column-oriented payload: one single-element list per field.
        # NOTE(review): the order presumably has to match the field order of
        # the collection schema -- confirm against CreateMilvusDbSchema.
        insert_data = [
            [str(chunk_doc_id)],          # Chunk ID (primary key)
            [str(doc_id)],                # Document ID
            [embeddings],                 # Embedding vector
            [float(doccontextrel)],       # Context relevance score
            [float(doccontextutil)],      # Context utilization score
            [float(docadherence)],        # Adherence score
            [str(datasetname)],           # Dataset name
            [float(relevance_score)],     # Relevance score
            [float(utilization_score)],   # Utilization score
            [float(completeness_score)],  # Completeness score
        ]
        collection.insert(insert_data)
    except Exception as e:
        print(f"Error inserting chunk {chunk_doc_id} doc {doc_id} (index {index}): {e}")
        return False
    return True
142
+
143
+
loaddataset.py CHANGED
@@ -1,7 +1,7 @@
1
  import pandas as pd
2
  from datasets import load_dataset
3
  from logger import logger
4
- from typing import Dict, List, Optional
5
 
6
 
7
  DATASET_CONFIGS = [
 
1
  import pandas as pd
2
  from datasets import load_dataset
3
  from logger import logger
4
+ from typing import Dict, List
5
 
6
 
7
  DATASET_CONFIGS = [
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
2
  transformers
3
  torch
4
  huggingface_hub
5
- pymilvus
 
 
2
  transformers
3
  torch
4
  huggingface_hub
5
+ pymilvus
6
+ nltk
searchmilvushelper.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def SearchTopKDocuments(collection, query_text, model, top_k=10):
    """Embed ``query_text`` and return the ``top_k`` most similar chunks from Milvus.

    Args:
        collection: Milvus collection instance; only ``search`` is called.
        query_text: Query string to embed.
        model: Embedding model exposing ``encode(text, convert_to_numpy=True)``.
        top_k: Maximum number of hits to request.

    Returns:
        list: One dict per hit holding every requested output field plus a
        "distance" key carrying ``hit.distance``. With the COSINE metric
        Milvus documents this value as a similarity (larger means closer),
        not a distance.
    """
    # Metadata fields requested from Milvus and copied into each result dict.
    output_fields = [
        "chunk_doc_id",         # Primary key
        "doc_id",               # Document ID
        "context_relevance",    # Context Relevance Score
        "context_utilization",  # Context Utilization Score
        "adherence",            # Adherence Score
        "dataset_name",         # Dataset Name
        "relevance_score",      # Relevance Score
        "utilization_score",    # Utilization Score
        "completeness_score",   # Completeness Score
    ]

    # Generate embedding for the query text
    query_embedding = model.encode(query_text, convert_to_numpy=True)

    # Define search parameters
    search_params = {
        "metric_type": "COSINE",  # Similarity metric
        # `ef` trades recall for speed. Milvus documents it as an HNSW search
        # parameter that must be >= the requested limit, so never let it drop
        # below top_k. NOTE(review): index type is not visible here -- confirm.
        "params": {"ef": max(64, top_k)},
    }

    # Perform the search
    results = collection.search(
        data=[query_embedding],
        anns_field="chunk_embedding",  # Field containing the embeddings
        param=search_params,
        limit=top_k,
        output_fields=output_fields,
    )

    # Flatten the per-query hit lists into one list of plain dicts
    top_documents = []
    for hits in results:
        for hit in hits:
            doc = {field: hit.entity.get(field) for field in output_fields}
            doc["distance"] = hit.distance  # Score reported by Milvus for this hit
            top_documents.append(doc)

    return top_documents