File size: 5,255 Bytes
b448ba9
 
 
 
 
 
 
197e8c2
b448ba9
 
c4eb043
 
cdd5fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4eb043
 
cdd5fd4
 
 
 
c4eb043
cdd5fd4
 
 
b448ba9
cdd5fd4
 
 
 
 
 
 
 
 
 
b448ba9
53e37f7
b448ba9
 
c4eb043
 
 
b448ba9
 
 
 
 
 
cdd5fd4
 
 
 
 
 
27d03da
cdd5fd4
27d03da
cdd5fd4
88c1917
6211123
cdd5fd4
6211123
 
 
27d03da
b828aa0
70b421f
b828aa0
 
b448ba9
 
 
 
 
 
 
 
 
 
cdd5fd4
 
b828aa0
cdd5fd4
 
ba1cf57
777cb0a
cdd5fd4
 
 
 
 
 
 
 
 
6c25308
 
 
 
 
 
 
 
cdd5fd4
 
b448ba9
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
import os
from huggingface_hub import hf_hub_download
from .configuration_ocrqa import ImpressoConfig

logger = logging.getLogger(__name__)
from pybloomfilter import BloomFilter
from transformers import pipeline
import unicodedata
from typing import Optional

# Character classes that normalization replaces with a single space.
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
# Latin-1 punctuation: inverted '!' and '?', both guillemets, middle dot.
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"
# Every digit is collapsed to "0".
DIGITS = "0123456789"

# Translation table for ``str.translate``: each punctuation/symbol character
# above maps to " ", and each digit maps to "0".
NORMALIZATION_TABLE = str.maketrans(
    {
        **dict.fromkeys(
            "".join(
                [
                    QUOTES_PUNCT,
                    ASCII_PUNCT,
                    BRACKETS_SPECIAL,
                    UNICODE_PUNCT,
                    DASH_CARET,
                    SPECIAL_SYMBOLS,
                    HYPHEN,
                ]
            ),
            " ",
        ),
        **dict.fromkeys(DIGITS, "0"),
    }
)


def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Map punctuation to spaces and digits to '0' via NORMALIZATION_TABLE.

    When ``unicode_normalize`` is truthy, the string is first normalized
    with that Unicode form and lowercased; when it is falsy, both steps are
    skipped and only the character translation is applied.
    """
    text = s
    if unicode_normalize:
        canonical = unicodedata.normalize(unicode_normalize, text)
        text = canonical.lower()
    return text.translate(NORMALIZATION_TABLE)


def filter_text(text: str, bloom_filter: BloomFilter):
    """Partition the tokens of *text* by membership in *bloom_filter*.

    The text is run through ``normalize_text`` and split on whitespace; each
    token is then tested with ``in`` against the filter.

    Returns a dict with two sets of distinct tokens: ``"known"`` (tokens the
    filter reports as present) and ``"unknown"`` (all others).
    """
    # Indexed by the membership result: True -> known, False -> unknown.
    buckets = {True: set(), False: set()}

    for word in normalize_text(text).split():
        buckets[word in bloom_filter].add(word)

    return {"known": buckets[True], "unknown": buckets[False]}

class QAAssessmentModel(PreTrainedModel):
    """OCR quality assessor scoring text by its share of known tokens.

    At construction time one Bloom filter per language key ("en", "de",
    "fr", "other") is fetched from the Hugging Face hub, and a language
    identification pipeline is created. ``forward`` detects the language of
    each input text, checks its normalized tokens against the matching
    filter and returns the fraction of distinct tokens the filter contains.
    """

    config_class = ImpressoConfig

    # Language keys for which a Bloom filter filename is read from the config.
    _LANGUAGES = ("en", "de", "fr", "other")

    def get_bloomfilter(self, model_id: str, filename: str):
        """Download *filename* from hub repo *model_id* and open it as a BloomFilter."""
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        """Load one Bloom filter per language and the language-ID pipeline.

        Args:
            config: configuration object; ``config.config.filename`` must map
                each language key to a filter filename and
                ``config.config._name_or_path`` must name the hub repository.

        Raises:
            KeyError: if a language key is missing from the filename mapping.
        """
        super().__init__(config)
        self.config = config

        # Dummy for device checking
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Resolve every filename up front so a missing entry fails before any
        # download is attempted.
        filenames = self.config.config.filename
        bin_filenames = {lang: filenames[lang] for lang in self._LANGUAGES}

        self.ocrqa_assessors = {}
        for lang, model_filename in bin_filenames.items():
            logger.info("Loading model for %s: %s", lang, model_filename)
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=self.config.config._name_or_path,
                filename=model_filename,
            )
        logger.debug("Loaded OCR QA assessors: %s", self.ocrqa_assessors)

        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    @staticmethod
    def _extract_language(detection):
        """Return the language code found in a pipeline result, or None.

        NOTE(review): the output shape of the "langident" pipeline is not
        visible in this file. The original code indexed ``['language']``
        while its comment showed ``[{'label': 'fr', ...}]``, so both a dict
        and a list of dicts, with either key, are accepted here.
        """
        if isinstance(detection, (list, tuple)):
            detection = detection[0] if detection else None
        if isinstance(detection, dict):
            lang = detection.get("language") or detection.get("label")
            if isinstance(lang, str) and lang:
                return lang
        return None

    def forward(self, input_ids, **kwargs):
        """Score the OCR quality of one text or a list of texts.

        Args:
            input_ids: a single string or a list of strings (raw text, not
                token ids).
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            A list with one score per input text: the number of distinct
            known tokens divided by the number of distinct tokens (plus a
            tiny epsilon, so a text without tokens scores 0.0).

        Raises:
            ValueError: if ``input_ids`` is neither a string nor a list of
                strings.
        """
        if isinstance(input_ids, str):
            # Wrap a single string so the loop below handles both cases.
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []
        for text in texts:
            # Detect the language of this text only, not of the whole batch.
            langs = self.lang_pipeline(text)
            logger.debug("Detected languages: %s", langs)

            lang = self._extract_language(langs)
            if lang is None:
                lang = "other"
                logger.warning("Language detection failed, using 'other' as default.")
            else:
                logger.info(f"Detected language: {lang}")

            if lang not in self.ocrqa_assessors:
                logger.warning(f"Language '{lang}' not found in bin_filename, using 'other' as default.")
                lang = "other"

            # Process the text using the selected filter
            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])

            # Fraction of known tokens; the epsilon avoids division by zero.
            score = known_count / (known_count + unknown_count + 0.000001)
            predictions.append(score)

        return predictions

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from a config made of ``kwargs``; no weights are loaded.

        Positional arguments are accepted but not used. All keyword
        arguments are forwarded to ``ImpressoConfig``.
        """
        # Manually create the config
        config = ImpressoConfig(**kwargs)
        # Pass the manually created config to the class
        return cls(config)