# skladbot-free-ai / custom_tokenizers.py
# (Hugging Face Hub file header: uploaded by Lyti4, commit message
#  "Update custom_tokenizers.py", revision 0f1c5d2, verified)
from transformers import T5Tokenizer
from typing import Dict, List, Optional, Union
import os
import logging
logger = logging.getLogger(__name__)
class Byt5LangTokenizer(T5Tokenizer):
    """Custom byte-level tokenizer for ByT5 models with table-recognition
    support; used with the ``vikp/surya_table`` model.

    Tokens are raw UTF-8 bytes mapped to ids 0-255.  The special tokens
    (eos, unk, pad) are placed directly above the byte range, starting at
    id 256, which keeps their ids consistent with :meth:`get_vocab` and
    :attr:`vocab_size`.
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )
        # byte_decoder maps each byte id to its single-byte value -- required
        # for ByT5-style byte-level decoding.
        self.byte_decoder = {i: bytes([i]) for i in range(256)}
        # Special tokens get deterministic ids directly above the 256 byte
        # ids.  The original code called self.convert_token_to_id(), which
        # does not exist on transformers tokenizers (the public API is
        # convert_tokens_to_ids), and the overridden _convert_token_to_id
        # could not run here anyway because special_tokens_encoder is not
        # assigned yet -- both paths raised AttributeError at construction.
        self.special_tokens = {
            token: 256 + offset
            for offset, token in enumerate((eos_token, unk_token, pad_token))
        }
        # Forward/reverse views used by the conversion helpers below.
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    @property
    def vocab_size(self):
        """Total vocabulary size: 256 byte ids plus the special tokens.

        The original referenced ``self.num_special_tokens``, which is not
        defined anywhere (AttributeError on first access).
        """
        return 256 + len(self.special_tokens)

    def get_vocab(self) -> Dict[str, int]:
        """Return the token -> id mapping: bytes as ``chr(i)``, then specials."""
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[int]:
        """Split *text* into its raw UTF-8 byte values (ints in 0-255)."""
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        """Map a token (byte int, single character, or special token) to its id."""
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            try:
                return ord(token)
            except TypeError:
                # ord() rejects strings that are not exactly one character.
                # The original returned the raw string here, leaking a str
                # where callers expect an int id; map to the unk id instead.
                return self.special_tokens_encoder.get(str(self.unk_token), 256)
        # Byte ids pass through unchanged.
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        """Inverse of ``_convert_token_to_id`` for ids in ``[0, vocab_size)``."""
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        """Reassemble mixed byte-int / string tokens into a UTF-8 string.

        Undecodable byte sequences are replaced (``errors="replace"``)
        rather than raising.
        """
        # b"".join avoids the quadratic bytes-concatenation of the original.
        parts = []
        for token in tokens:
            if isinstance(token, int):
                parts.append(bytes([token]))
            else:
                parts.append(token.encode("utf-8"))
        return b"".join(parts).decode("utf-8", errors="replace")