Delete tokenizer.py
tokenizer.py  +0 −180
tokenizer.py
DELETED
@@ -1,180 +0,0 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, Union

from tokenizers import Tokenizer as BaseTokenizer

from .aliases import PathOrStr
from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
from .exceptions import OLMoConfigurationError

__all__ = ["Tokenizer"]


class Tokenizer:
    """
    A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.

    :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
    :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
    :param truncate_to: Truncate when tokenizing to this number of token IDs.
    :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
        on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
        this setting has no effect.
    """

    def __init__(
        self,
        base_tokenizer: BaseTokenizer,
        eos_token_id: int,
        pad_token_id: Optional[int] = None,
        truncate_to: Optional[int] = None,
        truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
    ):
        self.base_tokenizer = base_tokenizer
        self.base_tokenizer.no_truncation()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
        self.truncate_to = truncate_to
        self.truncate_direction = TruncationDirection(truncate_direction)

    @property
    def vocab_size(self) -> int:
        return self.base_tokenizer.get_vocab_size()

    @property
    def eos_token(self) -> str:
        return self.decode([self.eos_token_id], skip_special_tokens=False)

    @property
    def pad_token(self) -> str:
        return self.decode([self.pad_token_id], skip_special_tokens=False)

    @classmethod
    def from_train_config(cls, config: TrainConfig) -> Tokenizer:
        tokenizer_identifier = config.tokenizer.identifier
        if Path(tokenizer_identifier).is_file():
            tokenizer = cls.from_file(
                tokenizer_identifier,
                eos_token_id=config.model.eos_token_id,
                pad_token_id=config.model.pad_token_id,
            )
        else:
            tokenizer = cls.from_pretrained(
                tokenizer_identifier,
                eos_token_id=config.model.eos_token_id,
                pad_token_id=config.model.pad_token_id,
            )
        if config.model.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer

    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a file.

        You can create those files with ``BaseTokenizer.save()``.

        :param filename: The name of a file containing a tokenizer specification.
        :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
        """
        base_tokenizer = BaseTokenizer.from_file(filename)
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
        """
        Load a tokenizer from a checkpoint.
        """
        from cached_path import cached_path

        # Load configs.
        config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
        tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
        model_config = ModelConfig.load(config_path, key="model")

        # Initialize tokenizer and validate vocab size.
        if Path(tokenizer_config.identifier).is_file():
            tokenizer = cls.from_file(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        else:
            tokenizer = cls.from_pretrained(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        if model_config.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer

    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        return 2 if is_pair else 1

    def _truncate(
        self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
    ) -> List[int]:
        if truncate_to is None or len(input_ids) <= truncate_to:
            return input_ids
        elif direction == TruncationDirection.left:
            return input_ids[len(input_ids) - truncate_to :]
        else:
            return input_ids[: -(len(input_ids) - truncate_to)]

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            truncate_to -= self.num_special_tokens_to_add(False)

        batch_encoding = self.base_tokenizer.encode_batch(inputs)

        all_input_ids = []
        for encoding in batch_encoding:
            input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)

        return all_input_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode a list of token IDs to a string.
        """
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
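For reference, a minimal sketch of how this wrapper was typically used before its removal. The `olmo.tokenizer` import path, the "gpt2" Hub identifier, and the chosen `eos_token_id`/`truncate_to` values are assumptions for illustration, not something this diff confirms.

# Minimal usage sketch of the removed Tokenizer wrapper (assumptions noted above).
from olmo.tokenizer import Tokenizer  # assumed module path before deletion

# Wrap a pretrained HuggingFace tokenizer whose repo contains a tokenizer.json;
# 50256 is GPT-2's end-of-text token ID, used here as the EOS token.
tok = Tokenizer.from_pretrained("gpt2", eos_token_id=50256, truncate_to=16)

# encode() truncates to the configured budget (minus one slot for the special
# token) and appends the EOS token ID by default.
ids = tok.encode("Hello, world!")
assert ids[-1] == tok.eos_token_id

# decode() skips special tokens unless told otherwise.
print(tok.decode(ids))
print(tok.decode(ids, skip_special_tokens=False))  # keeps the EOS text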