import logging
import unicodedata
from typing import Optional

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from pybloomfilter import BloomFilter
from transformers import PreTrainedModel, pipeline

from .configuration_ocrqa import ImpressoConfig

logger = logging.getLogger(__name__)

# Characters that are mapped to a space during normalization
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"
DIGITS = "0123456789"

# Punctuation becomes whitespace, digits are collapsed to "0"
NORMALIZATION_TABLE = str.maketrans(
    {
        char: " "
        for char in (
            QUOTES_PUNCT
            + ASCII_PUNCT
            + BRACKETS_SPECIAL
            + UNICODE_PUNCT
            + DASH_CARET
            + SPECIAL_SYMBOLS
            + HYPHEN
        )
    }
    | {char: "0" for char in DIGITS}
)


def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Normalize text by replacing punctuation with spaces and digits with '0'."""
    if unicode_normalize:
        s = unicodedata.normalize(unicode_normalize, s).lower()
    return s.translate(NORMALIZATION_TABLE)


def filter_text(text: str, bloom_filter: BloomFilter) -> dict:
    """Split normalized text into tokens and sort them into known/unknown sets."""
    knowns = set()
    unknowns = set()

    # Normalize and tokenize the text
    normalized_text = normalize_text(text)
    tokens = normalized_text.split()

    # Check each token against the bloom filter
    for token in tokens:
        if token in bloom_filter:
            knowns.add(token)
        else:
            unknowns.add(token)

    return {"known": knowns, "unknown": unknowns}


class QAAssessmentModel(PreTrainedModel):
    """OCR quality assessment model based on per-language bloom filters."""

    config_class = ImpressoConfig

    def get_bloomfilter(self, model_id: str, filename: str) -> BloomFilter:
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Dummy parameter so that device checking works
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Per-language bloom-filter files registered in the config
        bin_filenames = {
            lang: self.config.config.filename[lang]
            for lang in ("en", "de", "fr", "other")
        }

        self.ocrqa_assessors = {}
        for lang, model_filename in bin_filenames.items():
            print(f"Loading model for {lang}: {model_filename}")
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=self.config.config._name_or_path, filename=model_filename
            )

        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    def forward(self, input_ids, **kwargs):
        """Return one quality score per input text (share of known tokens)."""
        if isinstance(input_ids, str):
            # A single string is wrapped in a list
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []

        for text in texts:
            # Detect the language of the current text
            langs = self.lang_pipeline(text)
            if len(langs) > 0:
                lang = langs["language"]
                logger.info(f"Detected language: {lang}")
            else:
                lang = "other"
                logger.warning("Language detection failed, using 'other' as default.")

            if lang not in self.ocrqa_assessors:
                logger.warning(
                    f"Language '{lang}' has no bloom filter, using 'other' as default."
                )
                lang = "other"

            # Score the text with the bloom filter of the detected language
            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])

            # Quality score in [0, 1]: fraction of tokens found in the filter
            # (the small epsilon avoids division by zero for empty texts)
            score = known_count / (known_count + unknown_count + 0.000001)
            predictions.append(score)

        return predictions

    @property
    def device(self):
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Ignore any pretrained weights and use custom initialization:
        # manually create the config and pass it to the class
        config = ImpressoConfig(**kwargs)
        model = cls(config)
        return model
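

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the module itself). It
# assumes that ImpressoConfig's defaults resolve to a Hugging Face repo that
# ships the per-language bloom-filter files listed under `config.filename`;
# the repo-id string and the sample sentence below are placeholders. The model
# returns a list with one quality score in [0, 1] per input text, i.e. the
# share of normalized tokens found in the bloom filter of the detected
# language.
if __name__ == "__main__":
    # The positional path is ignored by the custom from_pretrained above;
    # the bloom filters are resolved through the ImpressoConfig defaults.
    qa_model = QAAssessmentModel.from_pretrained("<model-repo-id>")
    sample = "Thc quick brown fox jumps ovcr the lazy dog."
    print(qa_model(sample))  # e.g. a list with a single score such as [0.78]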