Spaces:
Sleeping
Sleeping
Update util/vector_base.py
Browse files- util/vector_base.py +79 -78
util/vector_base.py
CHANGED
@@ -1,79 +1,80 @@
|
|
1 |
-
import sys
|
2 |
-
from langchain_chroma import Chroma
|
3 |
-
from langchain_core.documents import Document
|
4 |
-
sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
|
5 |
-
|
6 |
-
from
|
7 |
-
import
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
requirement =
|
53 |
-
requirement = requirement
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
'
|
58 |
-
|
59 |
-
|
60 |
-
requirements_dict_v2[requirement]['
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
"
|
68 |
-
"
|
69 |
-
"
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
79 |
requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
|
|
|
1 |
+
import sys
|
2 |
+
from langchain_chroma import Chroma
|
3 |
+
from langchain_core.documents import Document
|
4 |
+
# sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
|
5 |
+
sys.path.append('/home/user/app/util')
|
6 |
+
from Embeddings import TextEmb3LargeEmbedding
|
7 |
+
from pathlib import Path
|
8 |
+
import time
|
9 |
+
|
10 |
+
class EmbeddingFunction():
    """Adapter exposing an embedding model via ``embed_query``/``embed_documents``.

    The wrapped model only needs a ``get_embedding(text)`` method that returns
    an iterable of numbers. Both methods return plain Python lists so that the
    single-query and batch paths produce the same vector type.
    """

    def __init__(self, embeddingmodel):
        """Store the model whose ``get_embedding`` method does the actual work."""
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        """Return the embedding of a single query text as a list."""
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        """Return one embedding list per document, in input order.

        Each text is embedded with its own ``get_embedding`` call; an empty
        input yields an empty list.
        """
        # Convert every vector to a list, exactly as embed_query does, so
        # callers never receive a mix of lists and model-specific sequences.
        return [list(self.embeddingmodel.get_embedding(document)) for document in documents]
|
17 |
+
|
18 |
+
def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> "Chroma":
    """Open the persisted Chroma vector store for a collection, or build it.

    If the persist directory for ``collection_name`` already exists, the store
    is opened as-is and ``documents`` is ignored. Otherwise a new store is
    created and populated. Documents are deliberately added one at a time with
    a pause between them, instead of in one batch, so that initialization does
    not fail by hitting the OpenAI API rate limit.

    Args:
        collection_name: Chroma collection name; also the last component of
            the persist directory.
        embedding: Object passed to Chroma as ``embedding_function``.
        documents: Documents used to populate a new store. Required (and
            non-empty) when the store does not exist yet.

    Returns:
        The opened or newly created Chroma vector store.

    Raises:
        ValueError: If the store does not exist and ``documents`` is empty
            or None.
    """
    persist_directory = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//" +collection_name
    # Check before constructing Chroma: the existence of the directory is what
    # distinguishes "open" from "create". `exists` must be *called*; the bare
    # bound method is always truthy, which silently disabled this guard.
    store_exists = Path(persist_directory).exists()
    if not store_exists and not documents:
        raise ValueError("vector store does not exist and documents is empty")

    if store_exists:
        print("vector store already exists")
    else:
        print("start creating vector store")

    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding,
        persist_directory=persist_directory
    )

    if not store_exists:
        for document in documents:
            vector_store.add_documents(documents=[document])
            # Throttle between insertions to stay under the API rate limit.
            time.sleep(1)
    return vector_store
|
46 |
+
|
47 |
+
def group_requirements(requirements_data):
    """Group table rows by normalized requirement text.

    Each row must provide the columns ``Requirement``, ``Details``,
    ``PCF-Privacy Objective`` and ``Safeguard``. The key is the part of
    ``Requirement`` after the first ``"- "`` separator, followed by ``": "``
    and ``Details``, with newlines, carriage returns and tabs replaced by
    spaces.

    Args:
        requirements_data: Object with an ``iterrows()`` method yielding
            ``(index, row)`` pairs, such as a pandas DataFrame.

    Returns:
        Dict mapping requirement text to ``{'PO': set, 'safeguard': set}``
        holding the lower-cased, right-stripped values seen for it.
    """
    grouped = {}
    for _, row in requirements_data.iterrows():
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        entry = grouped.setdefault(requirement, {'PO': set(), 'safeguard': set()})
        for column, key in (('PCF-Privacy Objective', 'PO'), ('Safeguard', 'safeguard')):
            value = row[column]
            # Empty cells are not strings (pandas reads them as NaN); skip
            # them for both columns rather than crashing on `.lower()`.
            if isinstance(value, str):
                entry[key].add(value.lower().rstrip())
    return grouped


if __name__=="__main__":
    import pandas as pd

    # Build the requirement_v2 vector store from the reference CSV.
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    requirements_dict_v2 = group_requirements(requirements_data)

    # One document per distinct requirement. Metadata lists are sorted so the
    # stored strings do not depend on set iteration order between runs.
    documents = [
        Document(
            page_content=requirement,
            metadata={
                "index": index,
                "version": 2,
                "PO": str(sorted(po for po in tags['PO'] if po)),
                "safeguard": str(sorted(tags['safeguard'])),
            },
        )
        for index, (requirement, tags) in enumerate(requirements_dict_v2.items())
    ]

    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
|