from transformers import T5Tokenizer
from typing import Dict, List, Optional, Union
import os
import logging

logger = logging.getLogger(__name__)

class Byt5LangTokenizer(T5Tokenizer):
    """
    Custom tokenizer for ByT5 models with table-recognition support.
    Used with the vikp/surya_table model.
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )
        # Build byte_decoder: ByT5 decoding maps every id 0-255 back to its raw byte.
        self.byte_decoder = {i: bytes([i]) for i in range(256)}
        # Register the special tokens and their ids
        # (convert_tokens_to_ids is the standard lookup on the parent tokenizer).
        self.special_tokens = {
            eos_token: self.convert_tokens_to_ids(eos_token),
            unk_token: self.convert_tokens_to_ids(unk_token),
            pad_token: self.convert_tokens_to_ids(pad_token),
        }
        # Provide the otherwise-missing encoder/decoder attributes
        # that the byte-level methods below rely on.
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    @property
    def vocab_size(self):
        # 256 raw byte ids plus the registered special tokens.
        return 256 + len(self.special_tokens)

    def get_vocab(self) -> Dict[str, int]:
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        # Byte-level tokenization: one token per UTF-8 byte of the input.
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            try:
                # Single-character strings fall back to their Unicode code point.
                return ord(token)
            except TypeError:
                return token
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        # Integer tokens are raw bytes, string tokens are re-encoded,
        # then the whole sequence is decoded as UTF-8.
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")
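
# ---------------------------------------------------------------------------
# Minimal round-trip sketch (an illustration, not part of the tokenizer): it
# mirrors the byte-level logic of _tokenize and convert_tokens_to_string above
# so it can be run standalone. The sample string is arbitrary, and the sketch
# deliberately avoids instantiating Byt5LangTokenizer, which would require the
# files that T5Tokenizer expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample = "Table: 2 rows x 3 cols"

    # Encode the way _tokenize does: one token id per UTF-8 byte.
    token_ids = list(sample.encode("utf-8"))

    # Decode the way convert_tokens_to_string does: each integer token is a
    # raw byte, and the joined bytes are decoded as UTF-8.
    decoded = b"".join(bytes([t]) for t in token_ids).decode("utf-8", errors="replace")

    assert decoded == sample
    print(f"round trip ok: {len(token_ids)} bytes -> {decoded!r}")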