import pandas as pd
import json
import sys
import os
from collections import defaultdict
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
import time
# bm25s is imported lazily inside multiretriever() only when using_BM25=True.
import contextlib
import io
from tqdm import tqdm
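
# Reciprocal Rank Fusion (RRF) turns the ranks one item receives from several retrievers
# into a single score: score(d) = sum_i 1 / (k + rank_i(d)), with k = 60 as the usual
# constant. The numbers below are illustrative only:
#   rrf([1, 2]) = 1/(60 + 1) + 1/(60 + 2) ≈ 0.0325   (ranked 1st and 2nd)
#   rrf([10])   = 1/(60 + 10)             ≈ 0.0143   (ranked 10th by a single retriever)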

def rrf(rankings, k=60):
    res = 0
    for r in rankings:
        res += 1 / (r + k)
    return res

def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
    """Retrieve candidate safeguards for one requirement.

    When if_split_po is True, each privacy objective (PO) is queried on its own and the
    per-PO top-k lists are concatenated; otherwise one query covers all POs at once.
    """
    if not if_split_po:
        return multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
    final_result = []
    for po in PO:
        po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
        final_result.extend(po_result)
    return final_result

def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
    """Rank the safeguards that belong to the given POs against a single requirement.

    Candidate safeguards are filtered by PO, ranked by any combination of a cross-encoder
    reranker, BM25 and Chroma similarity search, and the ranks are fused with Reciprocal
    Rank Fusion. Returns the top-k (rrf_score, safeguard_number, safeguard_text, po) tuples.
    """
    po_list = [po.lower().rstrip() for po in PO if po]
    # Skip retrieval when the only PO is "young users".
    if "young users" in po_list and len(po_list) == 1:
        return []

    # Fetch every safeguard whose PO metadata matches one of the requested POs.
    candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
    safeguard_dict, safeguard_content = {}, []
    for content, metadata in zip(candidate_safeguards['documents'], candidate_safeguards['metadatas']):
        safeguard_dict[content] = {
            "metadata": metadata,
            "rank": [],       # one rank appended per enabled retriever
            "rrf_score": 0
        }
        safeguard_content.append(content)

    # Cross-encoder reranker: score every (requirement, safeguard) pair and record the
    # resulting rank of each safeguard.
    if using_reranker:
        content_pairs = [[requirement, safeguard] for safeguard in safeguard_content]
        safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
        reranking_results = sorted(zip(safeguard_content, safeguard_rerank_scores), key=lambda x: x[1], reverse=True)
        for rank, (safeguard, _score) in enumerate(reranking_results, start=1):
            safeguard_dict[safeguard]['rank'].append(rank)

    # BM25 lexical retrieval; stdout is redirected to keep bm25s output quiet.
    if using_BM25:
        from bm25s import BM25, tokenize  # local import: bm25s is only required when enabled
        with contextlib.redirect_stdout(io.StringIO()):
            bm25_retriever = BM25(corpus=safeguard_content)
            bm25_retriever.index(tokenize(safeguard_content))
            bm25_results, _scores = bm25_retriever.retrieve(tokenize(requirement), k=len(safeguard_content))
            for rank, safeguard in enumerate(bm25_results[0], start=1):
                safeguard_dict[safeguard]['rank'].append(rank)

    # Chroma dense retrieval over the same PO-filtered candidates.
    if using_chroma:
        retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
        for rank, (doc, _score) in enumerate(retrieved_safeguards, start=1):
            safeguard_dict[doc.page_content]['rank'].append(rank)

    # Fuse the per-retriever ranks with RRF and keep the top-k safeguards.
    final_result = []
    for safeguard in safeguard_content:
        safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
        final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
    final_result.sort(key=lambda x: x[0], reverse=True)
    topk_final_result = final_result[:k]

    return topk_final_result
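
# The retrieval above assumes each record in the `safeguard_database` collection stores the
# safeguard text as its document plus, at minimum, the metadata keys read here: a lower-cased
# "po" label and a "safeguard_number". A record is expected to look roughly like this
# (values illustrative, inferred from the accesses above):
#   {
#       "document": "Limit the collection of personal information to what is necessary ...",
#       "metadata": {"po": "data minimization & purpose limitation", "safeguard_number": "3.2"},
#   }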

if __name__ == "__main__":
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
    reranker_model = FlagReranker(
        '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
        use_fp16=True,
        devices=["cpu"],
    )
    requirement = """
    Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure.
    """
    PO = ["Data Minimization & Purpose Limitation", "Transparency"]
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model,
        using_reranker=True,
        using_BM25=False,
        using_chroma=True,
        k=10,
    )
    print(final_result)
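    # `final_result` is a list of (rrf_score, safeguard_number, safeguard_text, po) tuples,
    # up to k per PO because if_split_po defaults to True. Illustrative shape only:
    #   [(0.0328, "3.2", "Limit the collection of personal information ...",
    #     "data minimization & purpose limitation"), ...]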