bytedancerneat committed on
Commit
0ab1e65
·
verified ·
1 Parent(s): d58e2aa

Update util/vector_base.py

Browse files
Files changed (1) hide show
  1. util/vector_base.py +79 -78
util/vector_base.py CHANGED
@@ -1,79 +1,80 @@
1
- import sys
2
- from langchain_chroma import Chroma
3
- from langchain_core.documents import Document
4
- sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
5
- from Embeddings import TextEmb3LargeEmbedding
6
- from pathlib import Path
7
- import time
8
-
9
- class EmbeddingFunction():
10
- def __init__(self, embeddingmodel):
11
- self.embeddingmodel = embeddingmodel
12
- def embed_query(self, query):
13
- return list(self.embeddingmodel.get_embedding(query))
14
- def embed_documents(self, documents):
15
- return [self.embeddingmodel.get_embedding(document) for document in documents]
16
-
17
- def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> Chroma:
18
- """
19
- 判断vector store是否已经构建好,如果没有构建好,则先初始化vector store。不使用embed_documents
20
- 方法批量初始化vector store而是for循环逐个加入,同时使用sleep,以此避免调用openai的接口达到最大
21
- 上限而导致初始化失败。
22
- """
23
- persist_directory = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//" +collection_name
24
- persist_path = Path(persist_directory)
25
- if not persist_path.exists and not documents:
26
- raise ValueError("vector store does not exist and documents is empty")
27
- elif persist_path.exists():
28
- print("vector store already exists")
29
- vector_store = Chroma(
30
- collection_name=collection_name,
31
- embedding_function=embedding,
32
- persist_directory=persist_directory
33
- )
34
- else:
35
- print("start creating vector store")
36
- vector_store = Chroma(
37
- collection_name=collection_name,
38
- embedding_function=embedding,
39
- persist_directory=persist_directory
40
- )
41
- for document in documents:
42
- vector_store.add_documents(documents=[document])
43
- time.sleep(1)
44
- return vector_store
45
-
46
- if __name__=="__main__":
47
- import pandas as pd
48
- requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
49
- requirements_dict_v2 = {}
50
- for index, row in requirements_data.iterrows():
51
- requirement = row['Requirement'].split("- ")[1]
52
- requirement = requirement + ": " + row['Details']
53
- requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
54
- if requirement not in requirements_dict_v2:
55
- requirements_dict_v2[requirement] = {
56
- 'PO': set(),
57
- 'safeguard': set()
58
- }
59
- requirements_dict_v2[requirement]['PO'].add(row['PCF-Privacy Objective'].lower().rstrip() if isinstance(row['PCF-Privacy Objective'], str) else None)
60
- requirements_dict_v2[requirement]['safeguard'].add(row['Safeguard'].lower().rstrip())
61
- index = 0
62
- documents = []
63
- for key, value in requirements_dict_v2.items():
64
- page_content = key
65
- metadata = {
66
- "index": index,
67
- "version":2,
68
- "PO": str([po for po in value['PO'] if po]),
69
- "safeguard":str([safeguard for safeguard in value['safeguard']])
70
- }
71
- index += 1
72
- document=Document(
73
- page_content=page_content,
74
- metadata=metadata
75
- )
76
- documents.append(document)
77
- embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
78
- embedding = EmbeddingFunction(embeddingmodel)
 
79
  requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
 
1
+ import sys
2
+ from langchain_chroma import Chroma
3
+ from langchain_core.documents import Document
4
+ # sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
5
+ sys.path.append('/home/user/app/util')
6
+ from Embeddings import TextEmb3LargeEmbedding
7
+ from pathlib import Path
8
+ import time
9
+
10
class EmbeddingFunction():
    """Adapter that exposes an embedding model via ``embed_query`` / ``embed_documents``.

    The wrapped model only has to provide ``get_embedding(text)``. In this
    file an instance is handed to ``Chroma`` as its ``embedding_function``.
    """

    def __init__(self, embeddingmodel):
        """Store the model whose ``get_embedding`` method produces the vectors."""
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        """Return the embedding of a single query text as a plain list."""
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        """Return one embedding (a plain list) per text in ``documents``.

        Each vector is converted with ``list`` so the element type matches
        what ``embed_query`` returns, whatever sequence type the model yields.
        Texts are embedded one at a time through ``get_embedding``.
        """
        return [list(self.embeddingmodel.get_embedding(document)) for document in documents]
17
+
18
def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> Chroma:
    """Open the persisted vector store for ``collection_name``, creating it if needed.

    If the persist directory already exists the store is simply opened and
    ``documents`` is ignored. Otherwise the store is created and populated.
    Documents are added one at a time in a loop, with a one-second sleep after
    each, instead of in a single bulk call, to avoid hitting the OpenAI API
    rate limit and failing the initialization.

    Args:
        collection_name: Chroma collection name; also the last component of
            the persist directory.
        embedding: Object passed to ``Chroma`` as ``embedding_function``.
        documents: Documents used to populate a new store. Required (and
            non-empty) only when the store does not exist yet.

    Returns:
        The opened or newly created ``Chroma`` vector store.

    Raises:
        ValueError: If the store does not exist and ``documents`` is empty
            or ``None``.
    """
    # NOTE(review): Windows-style absolute path, while the module's sys.path
    # entry points at /home/user/app/util -- confirm the intended location.
    persist_directory = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//" + collection_name

    # Evaluate existence once, before Chroma is constructed. The method must
    # be *called*: the bare attribute ``exists`` is always truthy, which
    # previously made the guard below unreachable.
    store_exists = Path(persist_directory).exists()

    if not store_exists and not documents:
        raise ValueError("vector store does not exist and documents is empty")

    if store_exists:
        print("vector store already exists")
    else:
        print("start creating vector store")

    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding,
        persist_directory=persist_directory
    )

    if not store_exists:
        # Insert one document per call and pause between calls (rate limit).
        for document in documents:
            vector_store.add_documents(documents=[document])
            time.sleep(1)
    return vector_store
46
+
47
if __name__ == "__main__":
    import pandas as pd

    # Script entry point: build the "requirement_v2_database" vector store
    # from the requirements CSV.
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")

    # Group rows by their normalized requirement text; each entry collects the
    # distinct privacy objectives and safeguards seen for that requirement.
    requirements_dict_v2 = {}
    for _, row in requirements_data.iterrows():
        # Keep the piece that follows the first "- " separator.
        title = row['Requirement'].split("- ")[1]
        requirement = title + ": " + row['Details']
        # Flatten line breaks and tabs into single spaces.
        for whitespace in ('\n', '\r', '\t'):
            requirement = requirement.replace(whitespace, ' ')

        entry = requirements_dict_v2.setdefault(requirement, {'PO': set(), 'safeguard': set()})

        objective = row['PCF-Privacy Objective']
        # Non-string objectives are recorded as None and filtered out when
        # the metadata is built below.
        entry['PO'].add(objective.lower().rstrip() if isinstance(objective, str) else None)
        entry['safeguard'].add(row['Safeguard'].lower().rstrip())

    # One Document per distinct requirement; the set-valued fields are stored
    # as the string form of a list.
    documents = [
        Document(
            page_content=key,
            metadata={
                "index": index,
                "version": 2,
                "PO": str([po for po in value['PO'] if po]),
                "safeguard": str(list(value['safeguard'])),
            },
        )
        for index, (key, value) in enumerate(requirements_dict_v2.items())
    ]

    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)