NLL_Interface / retriever.py
bytedancerneat's picture
Update retriever.py
1de832e verified
import pandas as pd
import json
import sys
import os
from collections import defaultdict
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
import time
# from bm25s import BM25, tokenize
import contextlib
import io
from tqdm import tqdm
def rrf(rankings, k=60):
    """Return the reciprocal-rank-fusion score for a single item.

    Each element of ``rankings`` is the rank the item received from one
    ranker; the fused score is ``sum(1 / (rank + k))``. An empty
    ``rankings`` yields ``0``.
    """
    return sum(1 / (rank + k) for rank in rankings)
def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
    """Retrieve ranked safeguards for ``requirement`` across the given POs.

    When ``if_split_po`` is false, all POs are ranked together in a single
    ``multiretriever`` call. Otherwise each PO is ranked on its own (up to
    ``k`` results per PO), and the per-PO result lists are concatenated in
    the order the POs appear in ``PO``.
    """
    options = dict(
        using_reranker=using_reranker,
        using_BM25=using_BM25,
        using_chroma=using_chroma,
        k=k,
    )
    if not if_split_po:
        return multiretriever(requirement, PO, safeguard_vector_store, reranker_model, **options)

    combined = []
    for single_po in PO:
        combined.extend(
            multiretriever(requirement, [single_po], safeguard_vector_store, reranker_model, **options)
        )
    return combined
def _append_ranks(safeguard_dict, ordered_contents):
    """Append a 1-based rank, best first, to each candidate in ``ordered_contents``.

    Only the first occurrence of a safeguard text is ranked, so duplicate
    texts cannot contribute extra RRF terms. Texts that are not keys of
    ``safeguard_dict`` are skipped and do not consume a rank.
    """
    seen = set()
    rank = 0
    for content in ordered_contents:
        if content in seen or content not in safeguard_dict:
            continue
        seen.add(content)
        rank += 1
        safeguard_dict[content]["rank"].append(rank)


def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
    """Rank the safeguards attached to the given POs for one requirement.

    Candidates are the documents in ``safeguard_vector_store`` whose ``po``
    metadata is one of ``PO``. Entries of ``PO`` are lower-cased, trailing
    whitespace is removed, and falsy entries are dropped. Each enabled ranker
    (cross-encoder reranker, BM25, Chroma similarity search) produces a 1-based
    ranking of the candidates. The rankings are fused with ``rrf``.

    Args:
        requirement: Requirement text used as the query.
        PO: Iterable of PO names.
        safeguard_vector_store: Store exposing ``get(where=...)`` and
            ``similarity_search_with_score(query=..., k=..., filter=...)``.
        reranker_model: Object exposing ``compute_score(pairs)``. It is only
            used when ``using_reranker`` is true.
        using_reranker, using_BM25, using_chroma: Which rankers to fuse.
        k: Maximum number of results to return.

    Returns:
        Up to ``k`` tuples ``(rrf_score, safeguard_number, safeguard_text, po)``,
        sorted by descending fused score. Returns ``[]`` in three cases: no
        usable PO is given, the only PO is "young users", or no stored
        safeguard matches the POs.
    """
    import numbers

    po_list = [po.lower().rstrip() for po in PO if po]
    # A request whose only PO is "young users" returns no safeguards. With no
    # PO at all there is nothing to filter on, and an empty "$in" filter may be
    # rejected by the store.
    if not po_list or po_list == ["young users"]:
        return []

    po_filter = {"po": {"$in": po_list}}
    candidates = safeguard_vector_store.get(where=po_filter)

    # Keyed by safeguard text. Identical texts collapse into one entry; as with
    # plain dict assignment, the metadata of the last occurrence wins.
    safeguard_dict = {}
    for content, metadata in zip(candidates["documents"], candidates["metadatas"]):
        safeguard_dict[content] = {"metadata": metadata, "rank": [], "rrf_score": 0}
    if not safeguard_dict:
        # Avoids compute_score([]) and zero-sized searches on an empty candidate set.
        return []
    safeguard_content = list(safeguard_dict)  # unique texts, first-seen order

    # Reranker
    if using_reranker:
        scores = reranker_model.compute_score([[requirement, content] for content in safeguard_content])
        if isinstance(scores, numbers.Real):
            # compute_score may return a bare score instead of a list for a single pair.
            scores = [scores]
        reranked = sorted(zip(safeguard_content, scores), key=lambda pair: pair[1], reverse=True)
        _append_ranks(safeguard_dict, (content for content, _score in reranked))

    # BM25
    if using_BM25:
        # NOTE(review): BM25 and tokenize come from bm25s, whose import is
        # commented out at the top of this file, so this branch raises
        # NameError until that import is restored.
        with contextlib.redirect_stdout(io.StringIO()):
            bm25_retriever = BM25(corpus=safeguard_content)
            bm25_retriever.index(tokenize(safeguard_content))
            bm25_results, _bm25_scores = bm25_retriever.retrieve(tokenize(requirement), k=len(safeguard_content))
        _append_ranks(safeguard_dict, bm25_results[0])

    # Chroma retrieval
    if using_chroma:
        # Ask for as many hits as raw candidate documents so every candidate can be ranked.
        retrieved = safeguard_vector_store.similarity_search_with_score(
            query=requirement, k=len(candidates["ids"]), filter=po_filter
        )
        _append_ranks(safeguard_dict, (document.page_content for document, _score in retrieved))

    final_result = []
    for content in safeguard_content:
        entry = safeguard_dict[content]
        entry["rrf_score"] = rrf(entry["rank"])
        metadata = entry["metadata"]
        final_result.append((entry["rrf_score"], metadata["safeguard_number"], content, metadata["po"]))
    # Stable sort: equal scores keep first-seen candidate order.
    final_result.sort(key=lambda item: item[0], reverse=True)
    return final_result[:k]
def _main():
    """Run a sample retrieval against the local safeguard database and print it.

    The reranker checkpoint path defaults to the path this script was written
    with. It can be overridden with the ``RERANKER_MODEL_PATH`` environment
    variable, so the script runs on machines that do not have that exact
    directory layout.
    """
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
    reranker_path = os.environ.get(
        "RERANKER_MODEL_PATH",
        '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
    )
    reranker_model = FlagReranker(
        reranker_path,
        use_fp16=True,
        devices=["cpu"],
    )
    requirement = (
        "\nData Minimization Consent for incompatible purposes: Require consent for "
        "additional use of personal information not reasonably necessary to or "
        "incompatible with original purpose disclosure.\n"
    )
    PO = ["Data Minimization & Purpose Limitation", "Transparency"]
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model,
        using_reranker=True,
        using_BM25=False,
        using_chroma=True,
        k=10
    )
    print(final_result)


if __name__ == "__main__":
    _main()