File size: 5,255 Bytes
b448ba9 197e8c2 b448ba9 c4eb043 cdd5fd4 c4eb043 cdd5fd4 c4eb043 cdd5fd4 b448ba9 cdd5fd4 b448ba9 53e37f7 b448ba9 c4eb043 b448ba9 cdd5fd4 27d03da cdd5fd4 27d03da cdd5fd4 88c1917 6211123 cdd5fd4 6211123 27d03da b828aa0 70b421f b828aa0 b448ba9 cdd5fd4 b828aa0 cdd5fd4 ba1cf57 777cb0a cdd5fd4 6c25308 cdd5fd4 b448ba9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
import os
from huggingface_hub import hf_hub_download
from .configuration_ocrqa import ImpressoConfig
logger = logging.getLogger(__name__)
from pybloomfilter import BloomFilter
from transformers import pipeline
import unicodedata
from typing import Optional
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"
DIGITS = "0123456789"

# All punctuation-like characters that are blanked out during normalization.
_PUNCTUATION = (
    QUOTES_PUNCT
    + ASCII_PUNCT
    + BRACKETS_SPECIAL
    + UNICODE_PUNCT
    + DASH_CARET
    + SPECIAL_SYMBOLS
    + HYPHEN
)

# Single-pass translation table: punctuation -> space, every digit -> "0".
NORMALIZATION_TABLE = str.maketrans(
    _PUNCTUATION + DIGITS,
    " " * len(_PUNCTUATION) + "0" * len(DIGITS),
)


def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Normalize *s* for bloom-filter lookup.

    When *unicode_normalize* is truthy it names a unicodedata normalization
    form (default "NFKC") applied together with lowercasing; punctuation is
    then replaced by spaces and digits by "0". With a falsy value the text
    is only passed through the translation table (no lowercasing).
    """
    prepared = (
        unicodedata.normalize(unicode_normalize, s).lower()
        if unicode_normalize
        else s
    )
    return prepared.translate(NORMALIZATION_TABLE)
def filter_text(text: str, bloom_filter: BloomFilter):
    """Partition the tokens of *text* into known and unknown words.

    The text is normalized and whitespace-tokenized, then each token is
    tested for membership in *bloom_filter*.

    Returns a dict with two sets: {"known": ..., "unknown": ...}.
    """
    tokens = normalize_text(text).split()
    knowns = {token for token in tokens if token in bloom_filter}
    unknowns = {token for token in tokens if token not in bloom_filter}
    return {"known": knowns, "unknown": unknowns}
class QAAssessmentModel(PreTrainedModel):
    """OCR quality-assessment model.

    For each input text, detects the language, looks the tokens up in a
    per-language bloom filter of known words, and returns the fraction of
    tokens that are known (a score in [0, 1]).
    """

    config_class = ImpressoConfig

    def get_bloomfilter(self, model_id: str, filename: str):
        """Download *filename* from the hub repo *model_id* and open it as a BloomFilter."""
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Dummy parameter so the model owns at least one parameter for device checking.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # One bloom filter per supported language; "other" is the fallback.
        self.ocrqa_assessors = {}
        for lang in ("en", "de", "fr", "other"):
            model_filename = self.config.config.filename[lang]
            logger.info("Loading bloom filter for %s: %s", lang, model_filename)
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=self.config.config._name_or_path, filename=model_filename
            )
        # Language-identification pipeline; kept on CPU deliberately.
        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    def forward(self, input_ids, **kwargs):
        """Score one text or a list of texts.

        Args:
            input_ids: a single string or a list of strings (raw text, not
                token ids, despite the transformers-style name).

        Returns:
            A list with one quality score per input text.

        Raises:
            ValueError: if *input_ids* is neither a string nor a list of strings.
        """
        if isinstance(input_ids, str):
            # A single string is wrapped in a list so one code path handles both.
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []
        for text in texts:
            # Fix: detect the language of the *current* text, not the whole
            # batch (the original passed `input_ids` here).
            langs = self.lang_pipeline(text)
            # NOTE(review): assumes the pipeline returns a mapping with a
            # 'language' key; an older comment suggested
            # [{'label': ..., 'confidence': ...}] — confirm the actual schema.
            if len(langs) > 0:
                lang = langs["language"]
                logger.info(f"Detected language: {lang}")
            else:
                lang = "other"
                logger.warning("Language detection failed, using 'other' as default.")
            if lang not in self.ocrqa_assessors:
                logger.warning(f"Language '{lang}' not found in bin_filename, using 'other' as default.")
                lang = "other"
            # Partition tokens into known/unknown via the language's bloom filter.
            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])
            # Fraction of known tokens; epsilon guards against empty texts.
            score = known_count / (known_count + unknown_count + 0.000001)
            predictions.append(score)
        return predictions

    @property
    def device(self):
        """Device of the model's parameters (resolved via the dummy parameter)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from a manually created config.

        Pretrained weights are deliberately ignored: all model state comes
        from the downloaded bloom filters and the language pipeline.
        """
        config = ImpressoConfig(**kwargs)
        return cls(config)
|