import pandas as pd
import json
import sys
import os
from collections import defaultdict
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
import time
# from bm25s import BM25, tokenize
import contextlib
import io
from tqdm import tqdm


def rrf(rankings, k=60):
    """Return the reciprocal-rank-fusion score of a single item.

    Args:
        rankings: Iterable of ranks (1 = best) the item received, one per
            ranker that scored it.
        k: Constant added to every rank before taking the reciprocal.

    Returns:
        The sum of ``1 / (rank + k)`` over ``rankings``; ``0`` when
        ``rankings`` is empty.
    """
    return sum(1 / (rank + k) for rank in rankings)


def retriever(requirement, PO, safeguard_vector_store, reranker_model,
              using_reranker=False, using_BM25=False, using_chroma=True,
              k=20, if_split_po=True):
    """Retrieve safeguards for ``requirement`` under the given PO labels.

    Args:
        requirement: Query text.
        PO: List of PO label strings used to filter the safeguard store.
        safeguard_vector_store: Vector store forwarded to ``multiretriever``.
        reranker_model: Reranker forwarded to ``multiretriever``.
        using_reranker: Forwarded to ``multiretriever``.
        using_BM25: Forwarded to ``multiretriever``.
        using_chroma: Forwarded to ``multiretriever``.
        k: Maximum number of results kept per ``multiretriever`` call.
        if_split_po: When true, every PO is queried on its own and the
            per-PO top-k lists are concatenated in PO order (so up to
            ``k * len(PO)`` results, not re-sorted). When false, all POs
            are queried together in a single call.

    Returns:
        A list of ``(rrf_score, safeguard_number, content, po)`` tuples.
    """
    ranker_flags = {
        "using_reranker": using_reranker,
        "using_BM25": using_BM25,
        "using_chroma": using_chroma,
        "k": k,
    }

    if not if_split_po:
        return multiretriever(requirement, PO, safeguard_vector_store,
                              reranker_model, **ranker_flags)

    final_result = []
    for po in PO:
        final_result.extend(
            multiretriever(requirement, [po], safeguard_vector_store,
                           reranker_model, **ranker_flags)
        )
    return final_result


def multiretriever(requirement, PO, safeguard_vector_store, reranker_model,
                   using_reranker=True, using_BM25=False, using_chroma=True,
                   k=20):
    """Rank the safeguards stored under ``PO`` against ``requirement``.

    Candidates are every document in the store whose ``po`` metadata is one
    of the normalised (lower-cased, right-stripped) PO labels. Each enabled
    ranker contributes one rank per candidate; the ranks are fused with
    ``rrf`` and the ``k`` best candidates are returned.

    Args:
        requirement: Query text.
        PO: List of PO label strings; falsy entries are ignored.
        safeguard_vector_store: Store exposing ``get(where=...)`` and
            ``similarity_search_with_score(query=..., k=..., filter=...)``.
        reranker_model: Object exposing ``compute_score(pairs)``; only used
            when ``using_reranker`` is true.
        using_reranker: Add a rank from the cross-encoder reranker.
        using_BM25: Add a rank from BM25 (needs the optional ``bm25s``
            package, imported only when this flag is set).
        using_chroma: Add a rank from the store's similarity search.
        k: Maximum number of results to return.

    Returns:
        A list of ``(rrf_score, safeguard_number, content, po)`` tuples
        sorted by ``rrf_score`` descending. Empty when there is no usable
        PO label, when the only label is ``"young users"``, or when the
        store holds no safeguard for the labels.
    """
    po_list = [po.lower().rstrip() for po in PO if po]

    # Nothing to filter on: an empty "$in" list is not a meaningful query
    # (and is expected to be rejected by the store), so answer directly.
    if not po_list:
        return []
    # A lone "young users" PO is deliberately skipped.
    if po_list == ["young users"]:
        return []

    po_filter = {"po": {"$in": po_list}}
    candidate_safeguards = safeguard_vector_store.get(where=po_filter)

    safeguard_dict, safeguard_content = {}, []
    for _doc_id, content, metadata in zip(candidate_safeguards['ids'],
                                          candidate_safeguards['documents'],
                                          candidate_safeguards['metadatas']):
        safeguard_dict[content] = {
            "metadata": metadata,
            "rank": [],
            "rrf_score": 0
        }
        safeguard_content.append(content)

    # No candidates: the result would be empty anyway, and the rankers below
    # would otherwise be invoked with an empty batch / k=0.
    if not safeguard_content:
        return []

    # Reranker
    if using_reranker:
        content_pairs = [[requirement, content] for content in safeguard_content]
        rerank_scores = reranker_model.compute_score(content_pairs)
        reranking_results = sorted(zip(safeguard_content, rerank_scores),
                                   key=lambda item: item[1], reverse=True)
        for rank, (content, _score) in enumerate(reranking_results, start=1):
            safeguard_dict[content]['rank'].append(rank)

    # BM25
    if using_BM25:
        # Optional dependency, imported lazily so the module still loads
        # when bm25s is not installed and BM25 is not requested.
        from bm25s import BM25, tokenize

        # Silence anything the library prints to stdout while indexing.
        with contextlib.redirect_stdout(io.StringIO()):
            bm25_retriever = BM25(corpus=safeguard_content)
            bm25_retriever.index(tokenize(safeguard_content))
            bm25_results, _bm25_scores = bm25_retriever.retrieve(
                tokenize(requirement), k=len(safeguard_content))
        # One query was issued, so row 0 holds its ranked documents.
        for rank, content in enumerate(bm25_results[0], start=1):
            safeguard_dict[content]['rank'].append(rank)

    # chroma retrieval
    if using_chroma:
        retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(
            query=requirement,
            k=len(candidate_safeguards['ids']),
            filter=po_filter,
        )
        for rank, (document, _score) in enumerate(retrieved_safeguards, start=1):
            safeguard_dict[document.page_content]['rank'].append(rank)

    final_result = []
    for content in safeguard_content:
        entry = safeguard_dict[content]
        entry['rrf_score'] = rrf(entry['rank'])
        final_result.append((
            entry['rrf_score'],
            entry['metadata']['safeguard_number'],
            content,
            entry['metadata']['po'],
        ))
    final_result.sort(key=lambda item: item[0], reverse=True)

    # top k
    return final_result[:k]


if __name__ == "__main__":
    # Manual smoke run against the local safeguard database.
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
    reranker_model = FlagReranker(
        '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
        use_fp16=True,
        devices=["cpu"],
    )
    requirement = (
        " Data \n"
        "Minimization Consent for incompatible purposes: Require consent for "
        "additional use of personal information not reasonably necessary to "
        "or incompatible with original purpose disclosure. "
    )
    PO = ["Data Minimization & Purpose Limitation", "Transparency"]
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model,
        using_reranker=True,
        using_BM25=False,
        using_chroma=True,
        k=10
    )
    print(final_result)