import pandas as pd
import json
import sys
import os
from collections import defaultdict
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
import time
# bm25s (BM25, tokenize) is imported lazily in multiretriever when using_BM25=True
import contextlib
import io
from tqdm import tqdm
def rrf(rankings, k=60):
    """Reciprocal Rank Fusion: sum 1 / (rank + k) over the ranks a document
    received from the individual retrievers (k=60 is the customary RRF constant)."""
    return sum(1 / (r + k) for r in rankings)
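# Example: a safeguard ranked 1st by the reranker and 3rd by the Chroma retrieval fuses to
# rrf([1, 3]) = 1/(1 + 60) + 1/(3 + 60) ≈ 0.0164 + 0.0159 ≈ 0.0323 (illustrative ranks, not real data).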
def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
    """Retrieve candidate safeguards for a requirement, either across all POs at once
    (if_split_po=False) or separately per PO, concatenating the per-PO results."""
    final_result = []
    if not if_split_po:
        final_result = multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
    else:
        for po in PO:
            po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
            final_result.extend(po_result)
    return final_result
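# Note on retriever(): with if_split_po=True the per-PO top-k lists are simply concatenated,
# so the combined result can contain up to k * len(PO) safeguards and is not re-sorted across POs.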
def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
    """
    Retrieve the top-k candidate safeguards for a single requirement.

    Args:
        requirement: requirement text used as the query.
        PO: list of PO labels used to filter candidate safeguards by metadata.
        safeguard_vector_store: Chroma vector store holding the safeguard documents.
        reranker_model: FlagReranker used when using_reranker=True.
        using_reranker / using_BM25 / using_chroma: which rankers contribute ranks to the RRF fusion.
        k: number of fused results to return.

    Returns:
        List of (rrf_score, safeguard_number, safeguard_text, po) tuples, sorted by descending RRF score.
    """
    po_list = [po.lower().rstrip() for po in PO if po]
    # No retrieval when "young users" is the only PO.
    if "young users" in po_list and len(po_list) == 1:
        return []
    candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
    safeguard_dict, safeguard_content = {}, []
    for _id, content, metadata in zip(candidate_safeguards['ids'], candidate_safeguards['documents'], candidate_safeguards['metadatas']):
        safeguard_dict[content] = {
            "metadata": metadata,
            "rank": [],
            "rrf_score": 0
        }
        safeguard_content.append(content)
    # Nothing matched the PO filter: avoid scoring/searching over an empty candidate set.
    if not safeguard_content:
        return []
    # Reranker: score each (requirement, safeguard) pair with the cross-encoder and rank by score.
    if using_reranker:
        content_pairs = [[requirement, safeguard] for safeguard in safeguard_content]
        safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
        # Some FlagEmbedding versions return a bare float when only one pair is scored.
        if not isinstance(safeguard_rerank_scores, list):
            safeguard_rerank_scores = [safeguard_rerank_scores]
        reranking_results = sorted(zip(safeguard_content, safeguard_rerank_scores), key=lambda x: x[1], reverse=True)
        for rank, (safeguard, _score) in enumerate(reranking_results, start=1):
            safeguard_dict[safeguard]['rank'].append(rank)
    # BM25: lexical ranking over the candidate safeguards.
    if using_BM25:
        from bm25s import BM25, tokenize  # imported lazily; only needed when BM25 fusion is enabled
        with contextlib.redirect_stdout(io.StringIO()):
            bm25_retriever = BM25(corpus=safeguard_content)
            bm25_retriever.index(tokenize(safeguard_content))
            bm25_results, scores = bm25_retriever.retrieve(tokenize(requirement), k=len(safeguard_content))
        for rank, safeguard in enumerate(bm25_results[0], start=1):
            safeguard_dict[safeguard]['rank'].append(rank)
    # Chroma retrieval: dense similarity search restricted to the same POs.
    if using_chroma:
        retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
        for rank, (doc, _score) in enumerate(retrieved_safeguards, start=1):
            safeguard_dict[doc.page_content]['rank'].append(rank)
    # Fuse the collected rankings with Reciprocal Rank Fusion and keep the top k.
    final_result = []
    for safeguard in safeguard_content:
        safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
        final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
    final_result.sort(key=lambda x: x[0], reverse=True)
    topk_final_result = final_result[:k]
    return topk_final_result
if __name__ == "__main__":
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
    reranker_model = FlagReranker(
        '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
        use_fp16=True,
        devices=["cpu"],
    )
requirement = """ | |
Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure. | |
""" | |
PO = ["Data Minimization & Purpose Limitation", "Transparency"] | |
final_result = retriever( | |
requirement, | |
PO, | |
safeguard_vector_store, | |
reranker_model, | |
using_reranker=True, | |
using_BM25=False, | |
using_chroma=True, | |
k=10 | |
) | |
print(final_result) |
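    # Each item in final_result is a (rrf_score, safeguard_number, safeguard_text, po) tuple,
    # sorted per PO by descending RRF score; e.g. final_result[0][2] would be the best-matching
    # safeguard text for the first PO (illustrative access pattern, not output from a real run).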