zehui127 committed
Commit f7c5681 · verified · 1 Parent(s): dd2d94a

Delete tokenizer.py

Files changed (1)
  1. tokenizer.py +0 -180
tokenizer.py DELETED
@@ -1,180 +0,0 @@
- from __future__ import annotations
-
- import os
- from pathlib import Path
- from typing import List, Optional, Union
-
- from tokenizers import Tokenizer as BaseTokenizer
-
- from .aliases import PathOrStr
- from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
- from .exceptions import OLMoConfigurationError
-
- __all__ = ["Tokenizer"]
-
-
- class Tokenizer:
-     """
-     A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.
-
-     :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
-     :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
-     :param truncate_to: Truncate when tokenizing to this number of token IDs.
-     :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
-         on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
-         this setting has no effect.
-     """
-
-     def __init__(
-         self,
-         base_tokenizer: BaseTokenizer,
-         eos_token_id: int,
-         pad_token_id: Optional[int] = None,
-         truncate_to: Optional[int] = None,
-         truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
-     ):
-         self.base_tokenizer = base_tokenizer
-         self.base_tokenizer.no_truncation()
-         self.eos_token_id = eos_token_id
-         self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
-         self.truncate_to = truncate_to
-         self.truncate_direction = TruncationDirection(truncate_direction)
-
-     @property
-     def vocab_size(self) -> int:
-         return self.base_tokenizer.get_vocab_size()
-
-     @property
-     def eos_token(self) -> str:
-         return self.decode([self.eos_token_id], skip_special_tokens=False)
-
-     @property
-     def pad_token(self) -> str:
-         return self.decode([self.pad_token_id], skip_special_tokens=False)
-
-     @classmethod
-     def from_train_config(cls, config: TrainConfig) -> Tokenizer:
-         tokenizer_identifier = config.tokenizer.identifier
-         if Path(tokenizer_identifier).is_file():
-             tokenizer = cls.from_file(
-                 tokenizer_identifier,
-                 eos_token_id=config.model.eos_token_id,
-                 pad_token_id=config.model.pad_token_id,
-             )
-         else:
-             tokenizer = cls.from_pretrained(
-                 tokenizer_identifier,
-                 eos_token_id=config.model.eos_token_id,
-                 pad_token_id=config.model.pad_token_id,
-             )
-         if config.model.vocab_size != tokenizer.vocab_size:
-             raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
-         return tokenizer
-
-     @classmethod
-     def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
-         """
-         Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.
-
-         :param identifier: The identifier of a model on the Hub that contains a
-             ``tokenizer.json`` file.
-         :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
-         """
-         base_tokenizer = BaseTokenizer.from_pretrained(identifier)
-         eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
-         return cls(base_tokenizer, eos_token_id, **kwargs)
-
-     @classmethod
-     def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
-         """
-         Initialize a tokenizer from a file.
-
-         You can create those files with ``BaseTokenizer.save()``.
-
-         :param filename: The name of a file containing a tokenizer specification.
-         :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
-         """
-         base_tokenizer = BaseTokenizer.from_file(filename)
-         eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
-         return cls(base_tokenizer, eos_token_id, **kwargs)
-
-     @classmethod
-     def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
-         """
-         Load a tokenizer from a checkpoint.
-         """
-         from cached_path import cached_path
-
-         # Load configs.
-         config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
-         tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
-         model_config = ModelConfig.load(config_path, key="model")
-
-         # Initialize tokenizer and validate vocab size.
-         if Path(tokenizer_config.identifier).is_file():
-             tokenizer = cls.from_file(
-                 tokenizer_config.identifier,
-                 eos_token_id=model_config.eos_token_id,
-                 pad_token_id=model_config.pad_token_id,
-             )
-         else:
-             tokenizer = cls.from_pretrained(
-                 tokenizer_config.identifier,
-                 eos_token_id=model_config.eos_token_id,
-                 pad_token_id=model_config.pad_token_id,
-             )
-         if model_config.vocab_size != tokenizer.vocab_size:
-             raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
-         return tokenizer
-
-     def add_special_tokens(self, input_ids: List[int]) -> List[int]:
-         """
-         Add special tokens in-place (if not already present) to the given token IDs.
-         """
-         if not input_ids or input_ids[-1] != self.eos_token_id:
-             input_ids.append(self.eos_token_id)
-         return input_ids
-
-     def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
-         return 2 if is_pair else 1
-
-     def _truncate(
-         self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
-     ) -> list[int]:
-         if truncate_to is None or len(input_ids) <= truncate_to:
-             return input_ids
-         elif direction == TruncationDirection.left:
-             return input_ids[len(input_ids) - truncate_to :]
-         else:
-             return input_ids[: -(len(input_ids) - truncate_to)]
-
-     def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
-         """
-         Encode a string into token IDs.
-         """
-         return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]
-
-     def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
-         """
-         Encode a batch of strings into token IDs.
-         """
-         truncate_to = self.truncate_to
-         if truncate_to is not None and add_special_tokens:
-             truncate_to -= self.num_special_tokens_to_add(False)
-
-         batch_encoding = self.base_tokenizer.encode_batch(inputs)
-
-         all_input_ids = []
-         for encoding in batch_encoding:
-             input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
-             if add_special_tokens:
-                 input_ids = self.add_special_tokens(input_ids)
-             all_input_ids.append(input_ids)
-
-         return all_input_ids
-
-     def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
-         """
-         Decode a list of token IDs to a string.
-         """
-         return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
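
For context, here is a minimal sketch of the behaviour the deleted wrapper provided, built directly on the tokenizers library: disable built-in truncation, truncate manually while reserving room for the EOS token, then append the EOS ID. The file path, sample text, and truncate_to value below are illustrative assumptions, not values taken from this repository.

from tokenizers import Tokenizer as BaseTokenizer

# Illustrative values only: a local tokenizer.json and a small truncation budget.
base = BaseTokenizer.from_file("tokenizer.json")
base.no_truncation()  # the wrapper disabled built-in truncation and truncated manually

eos_token_id = base.get_vocab_size() - 1  # same fallback default the deleted class used
truncate_to = 8

ids = base.encode("Hello, world!").ids
# Reserve one slot for EOS, truncate on the right, then append EOS if missing,
# mirroring Tokenizer.encode(..., add_special_tokens=True).
ids = ids[: truncate_to - 1]
if not ids or ids[-1] != eos_token_id:
    ids.append(eos_token_id)

text = base.decode(ids, skip_special_tokens=True)

Loading through the removed from_checkpoint path additionally validated that model.vocab_size in config.yaml matched the tokenizer's vocabulary size, raising OLMoConfigurationError on a mismatch.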