zehui127 committed on
Commit 8cea9fb · verified · 1 Parent(s): 369bb70

Upload folder using huggingface_hub
__init__.py ADDED
@@ -0,0 +1,22 @@
+
+
+ from .model.config import *
+ from .model.model import *
+ from .model.tokenizer import *
+ from .model.configuration_olmo import OLMoConfig
+ from .model.modeling_olmo import OLMoForCausalLM
+ from .model.modeling_olmo import OLMoForSequenceCLS
+ from .model.tokenization_olmo_fast import OLMoTokenizerFast
+
+
+
+ def check_install(cuda: bool = False):
+     import torch
+
+     from .version import VERSION
+
+     if cuda:
+         assert torch.cuda.is_available(), "CUDA is not available!"
+         print("CUDA available")
+
+     print(f"OLMo v{VERSION} installed")
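
__init__.py re-exports the Hugging Face-style wrappers (OLMoConfig, OLMoForCausalLM, OLMoForSequenceCLS, OLMoTokenizerFast) together with a small check_install helper. A minimal usage sketch, assuming the folder is importable as a package (olmo_gfm below is a placeholder name, not something this commit specifies) and that the wrappers follow the standard transformers from_pretrained interface:

    # Sketch only; "olmo_gfm" is a hypothetical package name for this folder.
    from olmo_gfm import OLMoForCausalLM, OLMoTokenizerFast, check_install

    check_install()            # prints "OLMo v<VERSION> installed"
    check_install(cuda=True)   # additionally asserts that CUDA is available

    # Assumes the usual transformers loading interface exposed by these wrappers.
    tokenizer = OLMoTokenizerFast.from_pretrained("path/to/this/repo")
    model = OLMoForCausalLM.from_pretrained("path/to/this/repo")
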
config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "activation_type": "swiglu",
+   "alibi": false,
+   "alibi_bias_max": 8.0,
+   "architectures": [
+     "OLMoModelForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "attention_layer_norm": false,
+   "attention_layer_norm_with_affine": false,
+   "bias_for_layer_norm": false,
+   "block_group_size": 1,
+   "block_type": "sequential",
+   "clip_qkv": null,
+   "d_model": 512,
+   "embedding_dropout": 0.0,
+   "embedding_size": 4096,
+   "eos_token_id": 3,
+   "flash_attention": false,
+   "include_bias": false,
+   "init_cutoff_factor": null,
+   "init_device": "meta",
+   "init_fn": "mitchell",
+   "init_std": 0.02,
+   "layer_norm_type": "rms",
+   "layer_norm_with_affine": true,
+   "max_sequence_length": 250,
+   "mlp_hidden_size": null,
+   "mlp_ratio": 8,
+   "model_type": "olmo-gfm",
+   "multi_query_attention": false,
+   "n_heads": 8,
+   "n_kv_heads": null,
+   "n_layers": 8,
+   "pad_token_id": 3,
+   "precision": "amp_bf16",
+   "residual_dropout": 0.0,
+   "rope": true,
+   "rope_full_precision": true,
+   "scale_logits": false,
+   "transformers_version": "4.47.1",
+   "use_cache": true,
+   "vocab_size": 4096,
+   "weight_tying": false
+ }
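
config.json registers a compact custom architecture under model_type "olmo-gfm": 8 transformer layers, d_model 512, 8 attention heads, SwiGLU activations, RMSNorm, rotary position embeddings, a 4096-entry vocabulary, and a 250-token maximum sequence length. A minimal sketch for inspecting those fields with plain JSON (loading through the bundled OLMoConfig would also work if it follows the usual PretrainedConfig interface, which is not verified here):

    import json

    with open("config.json") as f:
        cfg = json.load(f)

    print(cfg["model_type"])                                # olmo-gfm
    print(cfg["n_layers"], cfg["d_model"], cfg["n_heads"])  # 8 512 8
    print(cfg["vocab_size"], cfg["max_sequence_length"])    # 4096 250
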
config.yaml ADDED
@@ -0,0 +1,238 @@
+ run_name: OLMO-250-40GB-700M-10-epoch
+ seed: 6198
+ epoch: null
+ dry_run: false
+ model:
+   d_model: 512
+   n_heads: 8
+   n_kv_heads: null
+   clip_qkv: null
+   n_layers: 8
+   mlp_ratio: 8
+   mlp_hidden_size: null
+   activation_type: swiglu
+   block_type: sequential
+   block_group_size: 1
+   alibi: false
+   alibi_bias_max: 8.0
+   rope: true
+   rope_full_precision: true
+   flash_attention: false
+   attention_dropout: 0.0
+   multi_query_attention: false
+   attention_layer_norm: false
+   residual_dropout: 0.0
+   embedding_dropout: 0.0
+   layer_norm_type: rms
+   layer_norm_with_affine: true
+   attention_layer_norm_with_affine: false
+   max_sequence_length: 250
+   include_bias: false
+   bias_for_layer_norm: false
+   scale_logits: false
+   vocab_size: 4096
+   embedding_size: 4096
+   weight_tying: false
+   eos_token_id: 3
+   pad_token_id: 3
+   init_device: meta
+   init_fn: mitchell
+   init_std: 0.02
+   init_cutoff_factor: null
+   precision: amp_bf16
+ optimizer:
+   name: adamw
+   learning_rate: 0.0006
+   weight_decay: 0.1
+   betas:
+   - 0.9
+   - 0.95
+   no_decay_norm_and_bias: null
+   decay_norm_and_bias: false
+   decay_embeddings: false
+   metrics_log_interval: 10
+ scheduler:
+   name: cosine_with_warmup
+   units: steps
+   t_warmup: 5000
+   t_max: null
+   alpha_f: 0.1
+   grad_clip_warmup_steps: null
+   grad_clip_warmup_factor: null
+   warmup_min_lr: null
+ data:
+   paths:
+   - /mnt/data/tokenized_data/train_input_ids_1.npy
+   - /mnt/data/tokenized_data/train_input_ids_2.npy
+   - /mnt/data/tokenized_data/train_input_ids_3.npy
+   - /mnt/data/tokenized_data/train_input_ids_4.npy
+   - /mnt/data/tokenized_data/train_input_ids_5.npy
+   - /mnt/data/tokenized_data/train_input_ids_6.npy
+   - /mnt/data/tokenized_data/train_input_ids_7.npy
+   - /mnt/data/tokenized_data/train_input_ids_8.npy
+   - /mnt/data/tokenized_data/train_input_ids_9.npy
+   - /mnt/data/tokenized_data/train_input_ids_10.npy
+   - /mnt/data/tokenized_data/train_input_ids_11.npy
+   - /mnt/data/tokenized_data/train_input_ids_12.npy
+   - /mnt/data/tokenized_data/train_input_ids_13.npy
+   - /mnt/data/tokenized_data/train_input_ids_14.npy
+   - /mnt/data/tokenized_data/train_input_ids_15.npy
+   - /mnt/data/tokenized_data/train_input_ids_16.npy
+   - /mnt/data/tokenized_data/train_input_ids_17.npy
+   - /mnt/data/tokenized_data/train_input_ids_18.npy
+   - /mnt/data/tokenized_data/train_input_ids_19.npy
+   - /mnt/data/tokenized_data/train_input_ids_20.npy
+   - /mnt/data/tokenized_data/train_input_ids_21.npy
+   - /mnt/data/tokenized_data/train_input_ids_22.npy
+   - /mnt/data/tokenized_data/train_input_ids_23.npy
+   - /mnt/data/tokenized_data/train_input_ids_24.npy
+   - /mnt/data/tokenized_data/train_input_ids_25.npy
+   - /mnt/data/tokenized_data/train_input_ids_26.npy
+   - /mnt/data/tokenized_data/train_input_ids_27.npy
+   - /mnt/data/tokenized_data/train_input_ids_28.npy
+   - /mnt/data/tokenized_data/train_input_ids_29.npy
+   - /mnt/data/tokenized_data/train_input_ids_30.npy
+   - /mnt/data/tokenized_data/train_input_ids_31.npy
+   - /mnt/data/tokenized_data/train_input_ids_32.npy
+   - /mnt/data/tokenized_data/train_input_ids_33.npy
+   - /mnt/data/tokenized_data/train_input_ids_34.npy
+   - /mnt/data/tokenized_data/train_input_ids_35.npy
+   - /mnt/data/tokenized_data/train_input_ids_36.npy
+   - /mnt/data/tokenized_data/train_input_ids_37.npy
+   - /mnt/data/tokenized_data/train_input_ids_38.npy
+   - /mnt/data/tokenized_data/train_input_ids_39.npy
+   - /mnt/data/tokenized_data/train_input_ids_40.npy
+   - /mnt/data/tokenized_data/train_input_ids_41.npy
+   - /mnt/data/tokenized_data/train_input_ids_42.npy
+   - /mnt/data/tokenized_data/train_input_ids_43.npy
+   - /mnt/data/tokenized_data/train_input_ids_44.npy
+   - /mnt/data/tokenized_data/train_input_ids_45.npy
+   - /mnt/data/tokenized_data/train_input_ids_46.npy
+   - /mnt/data/tokenized_data/train_input_ids_47.npy
+   - /mnt/data/tokenized_data/train_input_ids_48.npy
+   - /mnt/data/tokenized_data/train_input_ids_49.npy
+   - /mnt/data/tokenized_data/train_input_ids_50.npy
+   - /mnt/data/tokenized_data/train_input_ids_51.npy
+   - /mnt/data/tokenized_data/train_input_ids_52.npy
+   - /mnt/data/tokenized_data/train_input_ids_53.npy
+   - /mnt/data/tokenized_data/train_input_ids_54.npy
+   - /mnt/data/tokenized_data/train_input_ids_55.npy
+   - /mnt/data/tokenized_data/train_input_ids_56.npy
+   - /mnt/data/tokenized_data/train_input_ids_57.npy
+   - /mnt/data/tokenized_data/train_input_ids_58.npy
+   - /mnt/data/tokenized_data/train_input_ids_59.npy
+   - /mnt/data/tokenized_data/train_input_ids_60.npy
+   - /mnt/data/tokenized_data/train_input_ids_61.npy
+   - /mnt/data/tokenized_data/train_input_ids_62.npy
+   - /mnt/data/tokenized_data/train_input_ids_63.npy
+   - /mnt/data/tokenized_data/train_input_ids_64.npy
+   - /mnt/data/tokenized_data/train_input_ids_65.npy
+   - /mnt/data/tokenized_data/train_input_ids_66.npy
+   - /mnt/data/tokenized_data/train_input_ids_67.npy
+   - /mnt/data/tokenized_data/train_input_ids_68.npy
+   - /mnt/data/tokenized_data/train_input_ids_69.npy
+   - /mnt/data/tokenized_data/train_input_ids_70.npy
+   - /mnt/data/tokenized_data/train_input_ids_71.npy
+   - /mnt/data/tokenized_data/train_input_ids_72.npy
+   - /mnt/data/tokenized_data/train_input_ids_73.npy
+   - /mnt/data/tokenized_data/train_input_ids_74.npy
+   - /mnt/data/tokenized_data/val_input_ids_2.npy
+   - /mnt/data/tokenized_data/val_input_ids_3.npy
+   - /mnt/data/tokenized_data/val_input_ids_4.npy
+   datasets: null
+   label_mask_paths: null
+   pad_direction: right
+   generate_attention_mask: false
+   num_workers: 16
+   drop_last: true
+   pin_memory: true
+   prefetch_factor: 16
+   persistent_workers: true
+   timeout: 0
+   seed: null
+ restore_dataloader: true
+ fast_forward_batches: null
+ evaluators:
+ - label: human-chunk
+   type: lm
+   data:
+     paths: null
+     datasets:
+       dna-bert2-eval:
+       - /mnt/data/tokenized_data/val_input_ids_1.npy
+     label_mask_paths: null
+     pad_direction: right
+     generate_attention_mask: false
+     num_workers: 16
+     drop_last: true
+     pin_memory: false
+     prefetch_factor: null
+     persistent_workers: false
+     timeout: 0
+     seed: null
+   device_eval_batch_size: null
+   subset_num_batches: null
+ eval_interval: 10000
+ tokenizer:
+   identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
+   truncate_direction: right
+ save_folder: /mnt/data/pretrain_formal_60M
+ remote_save_folder: null
+ canceled_check_interval: 50
+ save_interval: 10000
+ save_interval_unsharded: 20000
+ save_interval_ephemeral: null
+ save_num_checkpoints_to_keep: 1
+ save_num_unsharded_checkpoints_to_keep: 10
+ save_overwrite: true
+ force_save_unsharded: false
+ no_pre_train_checkpoint: false
+ load_path: null
+ load_path_sharded_checkpointer: null
+ reset_optimizer_state: false
+ reset_trainer_state: false
+ sharded_checkpointer: torch_legacy
+ new_style_checkpoints: null
+ max_duration: 10ep
+ global_train_batch_size: 384
+ device_train_batch_size: 48
+ device_train_microbatch_size: 48
+ device_eval_batch_size: 48
+ eval_subset_num_batches: -1
+ eval_on_load: false
+ device_train_grad_accum: 1
+ max_grad_norm: 1.0
+ max_grad_norm_ratio: null
+ precision: amp_bf16
+ wandb:
+   project: olmo-dna
+   entity: zehui127-imperial-college-london
+   group: null
+   name: OLMO-250-40GB-700M-10-epoch
+   tags:
+   - watching
+   log_artifacts: false
+   rank_zero_only: true
+   log_interval: 1
+ speed_monitor:
+   window_size: 3
+   gpu_flops_available: null
+ console_log_interval: 1
+ gen1_gc_interval: 1
+ compile: null
+ fsdp:
+   use_orig_params: true
+   sharding_strategy: FULL_SHARD
+   wrapping_strategy: null
+   precision: mixed
+ softmax_auxiliary_loss: false
+ time_limit: 964000.0
+ extra_steps_after_cancel: 10
+ early_stopping_factor: null
+ save_data_indices: true
+ python_profiling: false
+ torch_profiling: false
+ stop_at: null
+ stop_after: null
+ activation_checkpointing: null
+ fused_loss: null
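
Two training settings above are worth spelling out: the optimizer peaks at a learning rate of 6e-4, and the cosine_with_warmup scheduler warms up for 5000 steps before decaying toward alpha_f times the peak (t_max is left null, so the decay horizon is resolved at run time). A rough sketch of that schedule under the usual cosine-with-warmup rule; the horizon below is a placeholder and the exact OLMo implementation may differ in details:

    import math

    def lr_at(step: int, peak_lr: float = 6e-4, t_warmup: int = 5000,
              t_max: int = 100_000, alpha_f: float = 0.1) -> float:
        # Linear warmup to the peak, then cosine decay down to alpha_f * peak.
        if step < t_warmup:
            return peak_lr * step / t_warmup
        progress = min(1.0, (step - t_warmup) / max(1, t_max - t_warmup))
        min_lr = alpha_f * peak_lr
        return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

    print(lr_at(2500))   # mid-warmup: 3e-4
    print(lr_at(5000))   # peak: 6e-4
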
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:204b4ffc81fe15e3c58ef8462cf9336525c53dd3cd835fd82d8949e439a8367a
+ size 151049935
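
model.pt is stored through Git LFS; the three lines above are only the pointer (a SHA-256 digest and a size of roughly 151 MB), so the actual object has to be fetched (for example with git lfs pull or a Hub download) before it can be used. A hedged loading sketch; whether the file holds a bare state_dict or a fuller checkpoint is not shown in this commit:

    import torch

    # Requires the LFS object to have been pulled; map_location keeps it on CPU.
    # Newer torch versions may need weights_only=False for non-tensor payloads.
    checkpoint = torch.load("model.pt", map_location="cpu")
    print(type(checkpoint))
    if isinstance(checkpoint, dict):
        print(list(checkpoint.keys())[:5])   # peek at the first few keys
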
model/__pycache__/aliases.cpython-312.pyc ADDED
Binary file (316 Bytes).
model/__pycache__/beam_search.cpython-312.pyc ADDED
Binary file (45.5 kB).
model/__pycache__/config.cpython-312.pyc ADDED
Binary file (25.1 kB).
model/__pycache__/configuration_olmo.cpython-312.pyc ADDED
Binary file (2.14 kB).
model/__pycache__/exceptions.cpython-312.pyc ADDED
Binary file (1.7 kB).
model/__pycache__/initialization.cpython-312.pyc ADDED
Binary file (5.09 kB).
model/__pycache__/model.cpython-312.pyc ADDED
Binary file (77 kB).
model/__pycache__/modeling_olmo.cpython-312.pyc ADDED
Binary file (21.4 kB).
model/__pycache__/tokenization_olmo_fast.cpython-312.pyc ADDED
Binary file (611 Bytes).
model/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (9.05 kB).
model/__pycache__/torch_util.cpython-312.pyc ADDED
Binary file (8.1 kB).
model/__pycache__/util.cpython-312.pyc ADDED
Binary file (34.8 kB).
model/aliases.py ADDED
@@ -0,0 +1,7 @@
+ from os import PathLike
+ from typing import Union
+
+ __all__ = ["PathOrStr"]
+
+
+ PathOrStr = Union[str, PathLike]
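
model/aliases.py defines a single alias, PathOrStr, used for arguments that accept either a plain string or an os.PathLike object. A tiny illustration (the helper below is hypothetical, not part of this repo, and the import path assumes the repo root is on sys.path):

    from pathlib import Path

    from model.aliases import PathOrStr

    def read_text(path: PathOrStr) -> str:   # hypothetical helper, for illustration only
        with open(path) as f:
            return f.read()

    read_text("some_file.txt")         # a str works
    read_text(Path("some_file.txt"))   # so does a pathlib.Path
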
model/beam_search.py ADDED
@@ -0,0 +1,1078 @@
1
+ """
2
+ This is a self-contained and flexible beam search implementation adapted from
3
+ AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
+ """
5
+
6
+ import copy
7
+ import warnings
8
+ from abc import abstractmethod
9
+ from inspect import signature
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
+
12
+ import torch
13
+
14
+ __all__ = [
15
+ "Sampler",
16
+ "DeterministicSampler",
17
+ "MultinomialSampler",
18
+ "TopKSampler",
19
+ "TopPSampler",
20
+ "GumbelSampler",
21
+ "FinalSequenceScorer",
22
+ "SequenceLogProbabilityScorer",
23
+ "LengthNormalizedSequenceLogProbabilityScorer",
24
+ "Constraint",
25
+ "RepeatedNGramBlockingConstraint",
26
+ "BeamSearch",
27
+ ]
28
+
29
+ StateType = Dict[str, torch.Tensor]
30
+ StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
+ StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
+
33
+ StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
+ """
35
+ The type of step function that can be passed to [`BeamSearch.search`](#search).
36
+
37
+ This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
+ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
+ """
40
+
41
+ ConstraintStateType = List[List[Dict[str, Any]]]
42
+
43
+
44
+ class Sampler:
45
+ """
46
+ An abstract class that can be used to sample candidates (either nodes or beams)
47
+ within `BeamSearch`.
48
+
49
+ A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
+
51
+ `init_state()` takes three arguments:
52
+
53
+ - a tensor of starting log probs with shape `(batch_size, num_classes)`,
54
+ - the batch size, an int,
55
+ - and the number of classes, also an int.
56
+
57
+ It returns a state dictionary with any state tensors needed for subsequent
58
+ calls to `sample_nodes()` and `sample_beams()`.
59
+
60
+ By default this method just returns an empty dictionary.
61
+
62
+ Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
+
64
+ - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
+ - an integer representing the number of samples to take for each example in the batch,
66
+ - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
+ track of state.
68
+
69
+ For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
+ `num_examples = beam_size * per_node_beam_size`.
71
+
72
+ The return value should be a tuple containing:
73
+
74
+ - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
+ - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
+ - and the updated state dictionary.
77
+
78
+ A default implementation of `sample_beams` is provided, which just deterministically
79
+ picks the `k` examples with highest log probability.
80
+ """
81
+
82
+ def init_state(
83
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
+ ) -> StateType:
85
+ del start_class_log_probabilities, batch_size, num_classes
86
+ return {}
87
+
88
+ @abstractmethod
89
+ def sample_nodes(
90
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
+ raise NotImplementedError
93
+
94
+ def sample_beams(
95
+ self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
+ del state
98
+ selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
+ return selected_log_probs, selected_indices, {}
100
+
101
+
102
+ class DeterministicSampler(Sampler):
103
+ """
104
+ A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
+ log probability.
106
+ """
107
+
108
+ def sample_nodes(
109
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
+ del state
112
+ selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
+ return selected_log_probs, selected_indices, {}
114
+
115
+
116
+ class MultinomialSampler(Sampler):
117
+ """
118
+ A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
+ in the default, non-deterministic way.
120
+
121
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
+ above 1.0 produces a flatter probability distribution.
123
+ :param with_replacement: Whether to sample with replacement.
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ temperature: float = 1.0,
130
+ with_replacement: bool = False,
131
+ ) -> None:
132
+ self.temperature = temperature
133
+ self.with_replacement = with_replacement
134
+
135
+ def sample_nodes(
136
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
+ if self.temperature != 1.0:
139
+ _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
+ else:
141
+ _probabilities = log_probs.exp()
142
+
143
+ selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
+
145
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
+
147
+
148
+ class TopKSampler(Sampler):
149
+ """
150
+ A `Sampler` which redistributes the probability mass function for nodes among the
151
+ top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
+
153
+ Beams are sampled in the default, deterministic way.
154
+
155
+ :param k: The number of top choices to be selected from.
156
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
+ above 1.0 produces a flatter probability distribution.
158
+ :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ k: int = 1,
164
+ temperature: float = 1.0,
165
+ with_replacement: bool = False,
166
+ ):
167
+ self.k = k
168
+ self.temperature = temperature or 1.0
169
+ self.with_replacement = with_replacement
170
+
171
+ def sample_nodes(
172
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
+ if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
+ raise ValueError(
176
+ "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
+ )
178
+
179
+ # shape (both): (batch_size, k)
180
+ top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
+
182
+ # Apply temperature if necessary.
183
+ # shape: (batch_size, k)
184
+ if self.temperature != 1.0:
185
+ top_k_log_probs = top_k_log_probs / self.temperature
186
+
187
+ # Re-normalize the subset.
188
+ # shape: (batch_size, k)
189
+ normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
+
191
+ # Sample from the re-normalized subset.
192
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
+ # shape: (batch_size, per_node_beam_size)
194
+ sampled_indices = torch.multinomial(
195
+ normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
+ )
197
+
198
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
+ # shape: (batch_size, per_node_beam_size)
200
+ indices = top_k_indices.gather(-1, sampled_indices)
201
+
202
+ return log_probs.gather(1, indices), indices, state
203
+
204
+
205
+ class TopPSampler(Sampler):
206
+ """
207
+ A `Sampler` which redistributes the probability mass function for nodes among
208
+ the top choices with a cumulative probability of at least `p`, then samples from that subset
209
+ after re-normalizing the probabilities.
210
+
211
+ Beams are sampled in the default, deterministic way.
212
+
213
+ :param p:
214
+ The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
+ examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
+ insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
+ `per_node_beam_size` examples will be chosen.
218
+ :param temperature:
219
+ A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
+ above 1.0 produces a flatter probability distribution.
221
+ :param with_replacement:
222
+ If set to `True`, samples will be selected with replacement from the top choices.
223
+
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ p: float = 0.9,
229
+ temperature: float = 1.0,
230
+ with_replacement: bool = False,
231
+ ):
232
+ if p < 0.0 or p > 1.0:
233
+ raise ValueError("p must be a positive float no greater than 1.0")
234
+ self.p = p
235
+ self.temperature = temperature or 1.0
236
+ self.with_replacement = with_replacement
237
+
238
+ def sample_nodes(
239
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
+ if not per_node_beam_size <= log_probs.size()[1]:
242
+ raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
+
244
+ # First apply temperature coefficient:
245
+ if self.temperature != 1.0:
246
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
+ else:
248
+ _log_probs = log_probs
249
+
250
+ # Sort the probabilities in descending order to then find cumulative sum
251
+ log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
+
253
+ # shape: (batch_size, num_classes)
254
+ probabilities_descending = log_probs_descending.exp()
255
+ probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
+
257
+ # Create a mask for filtering out probabilities that don't make the top `p`.
258
+ # shape: (batch_size, num_classes)
259
+ exclusion_mask = probabilities_summed >= self.p
260
+
261
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
+ exclusion_mask[..., 0] = False
264
+
265
+ # Make sure there's at least `per_node_beam_size` options to be selected.
266
+ if not self.with_replacement:
267
+ exclusion_mask[..., :per_node_beam_size] = False
268
+
269
+ log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
+
271
+ # Now re-normalize the included log probs.
272
+ # shape: (batch_size, num_classes)
273
+ filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
+
275
+ # Sample from the re-normalized subset.
276
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
+ # shape: (batch_size, per_node_beam_size)
278
+ sampled_indices = torch.multinomial(
279
+ filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
+ )
281
+
282
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
+ # shape: (batch_size, per_node_beam_size)
284
+ selected_indices = sorting_indices.gather(-1, sampled_indices)
285
+
286
+ # Return (selected log probabilities, selected classes)
287
+ # shape: (len(log_probs),1) , (len(log_probs), 1)
288
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
+
290
+
291
+ class GumbelSampler(Sampler):
292
+ """
293
+ A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
+ [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
+ Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010]
296
+ (https://api.semanticscholar.org/CorpusID:76662039).
297
+
298
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
+ above 1.0 produces a flatter probability distribution.
300
+ """
301
+
302
+ def __init__(self, temperature: float = 1.0):
303
+ self.temperature = temperature
304
+
305
+ def init_state(
306
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
+ ) -> StateType:
308
+ # shape: (batch_size, num_classes)
309
+ zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
+
311
+ # shape: (batch_size, num_classes)
312
+ G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
+
314
+ return {"G_phi_S": G_phi_S}
315
+
316
+ def sample_nodes(
317
+ self,
318
+ log_probs: torch.Tensor,
319
+ per_node_beam_size: int,
320
+ state: StateType,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
+ # First apply temperature coefficient:
323
+ # shape: (batch_size * beam_size, num_classes)
324
+ if self.temperature != 1.0:
325
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
+ else:
327
+ _log_probs = log_probs
328
+
329
+ # shape: (group_size,)
330
+ phi_S = state["phi_S"]
331
+
332
+ # shape: (group_size, num_classes)
333
+ phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
+
335
+ # shape: (group_size, num_classes)
336
+ phi_S_new = phi_S + _log_probs
337
+
338
+ # shape: (group_size, 1)
339
+ G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
+
341
+ # shape: (group_size, num_classes)
342
+ G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
+
344
+ # Replace NaNs with very negative number.
345
+ # shape: (group_size, num_classes)
346
+ # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
+
348
+ # shape (both): (group_size, per_node_beam_size)
349
+ top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
+
351
+ # shape: (group_size, per_node_beam_size)
352
+ top_log_probs = log_probs.gather(1, top_indices)
353
+
354
+ return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
+
356
+ def sample_beams(
357
+ self,
358
+ log_probs: torch.Tensor,
359
+ beam_size: int,
360
+ state: StateType,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
+ """
363
+ Returns the beams with the highest perturbed log probabilities.
364
+ """
365
+ # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
+
367
+ batch_size = log_probs.size()[0]
368
+
369
+ # shape: (batch_size * beam_size, per_node_beam_size)
370
+ G_phi_S = state["G_phi_S"]
371
+
372
+ # shape: (batch_size, beam_size * per_node_beam_size)
373
+ G_phi_S = G_phi_S.reshape_as(log_probs)
374
+
375
+ # shape (both): (batch_size, beam_size)
376
+ G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
+
378
+ # shape: (batch_size, beam_size)
379
+ selected_log_probs = log_probs.gather(1, selected_indices)
380
+
381
+ # Now sort the selected beams by their true log prob.
382
+ # shape (all): (batch_size, beam_size)
383
+ selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
+ selected_indices = selected_indices.gather(1, sort_indices)
385
+ G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
+
387
+ # shape: (batch_size * beam_size,)
388
+ G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
+
390
+ # shape: (batch_size * beam_size,)
391
+ phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
+
393
+ return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
+
395
+ def gumbel(self, phi) -> torch.Tensor:
396
+ """
397
+ Sample `Gumbel(phi)`.
398
+
399
+ `phi` should have shape `(batch_size, num_classes)`.
400
+ """
401
+ return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
+
403
+ def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
+ """
405
+ Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
+
407
+ `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
+ shape `(batch_size, 1)`.
409
+ """
410
+ # Shape: (batch_size, num_classes)
411
+ G_phi = self.gumbel(phi)
412
+
413
+ # Now we find the maximum from these samples.
414
+ # Shape: (batch_size, )
415
+ Z, _ = G_phi.max(dim=-1)
416
+
417
+ # Shape: (batch_size, num_classes)
418
+ v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
+
420
+ # Shape: (batch_size, num_classes)
421
+ return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
+
423
+
424
+ class FinalSequenceScorer:
425
+ """
426
+ An abstract class that can be used to score the final generated sequences found
427
+ by beam search. Given the predicted sequences and the corresponding log probabilities of
428
+ those sequences, the class calculates and returns the final score of the sequences.
429
+
430
+ The default implementation scores the sequences using the sum of the log probabilities of
431
+ the sequence, which is passed as input.
432
+ """
433
+
434
+ @abstractmethod
435
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
+ """
437
+ Score the final predictions found by beam search.
438
+ Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
+
440
+ :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
+ :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
+ of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
+ :param end_index: The index of the end symbol.
444
+
445
+ """
446
+ raise NotImplementedError
447
+
448
+
449
+ class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
+ """
451
+ A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
+ across the sequence's tokens.
453
+ """
454
+
455
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
+ del predictions, end_index
457
+ # The sum of the sequence log probabilities is the input parameter, so just
458
+ # return it.
459
+ return log_probabilities
460
+
461
+
462
+ class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
+ """
464
+ A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
+ tokens in the sequence. It optionally includes a length penalty which promotes
466
+ or demotes sequences based on their lengths. The final score for a sequence will
467
+ be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
+ here includes the end token.
469
+
470
+ :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
+ A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
+ """
473
+
474
+ def __init__(self, length_penalty: float = 1.0):
475
+ super().__init__()
476
+ self.length_penalty = length_penalty
477
+
478
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
+ # shape: (batch_size, beam_size)
480
+ lengths = (predictions != end_index).long().sum(dim=2)
481
+
482
+ # If the sequence ended during beam search, the `log_probabilities` will include
483
+ # the transition to the end token. Therefore, in such situations, `lengths` is
484
+ # actually off by 1. This corrects for that.
485
+ # shape: (batch_size, beam_size)
486
+ is_end_token = predictions[:, :, -1] == end_index
487
+ lengths += is_end_token.long()
488
+
489
+ # shape: (batch_size, beam_size)
490
+ average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
+ return average_log_probs
492
+
493
+
494
+ class Constraint:
495
+ """
496
+ An abstract class that can be used to enforce constraints on the output predictions
497
+ by manipulating the class log probabilities during beam search.
498
+
499
+ A `Constraint` just has three methods that need to be implemented by subclasses:
500
+ `init_state()`, `apply()` and `_update_state()`.
501
+
502
+ `init_state()` takes one argument:
503
+
504
+ - the batch size, an int
505
+
506
+ It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
+ calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
+ Each inner list should be of length 1.
509
+
510
+ `apply()` takes two arguments:
511
+
512
+ - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
+ and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
+ - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
+ log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
+
517
+ The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
+ for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
+ the corresponding log probability to a negligible value such as `float("-inf")` or
520
+ `torch.finfo(class_log_probabilities.dtype).min`.
521
+
522
+ `_update_state()` takes two arguments:
523
+
524
+ - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
+ copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
+ directly edited in-place without affecting the others.
527
+ - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
+ step of beam search.
529
+
530
+ The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
+ length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
+
533
+ """
534
+
535
+ @abstractmethod
536
+ def init_state(
537
+ self,
538
+ batch_size: int,
539
+ ) -> ConstraintStateType:
540
+ raise NotImplementedError
541
+
542
+ @abstractmethod
543
+ def apply(
544
+ self,
545
+ state: ConstraintStateType,
546
+ class_log_probabilities: torch.Tensor,
547
+ ) -> torch.Tensor:
548
+ raise NotImplementedError
549
+
550
+ @staticmethod
551
+ def _copy_state(
552
+ state: ConstraintStateType,
553
+ batch_size: int,
554
+ beam_size: int,
555
+ last_backpointer: Optional[torch.Tensor] = None,
556
+ ) -> ConstraintStateType:
557
+ """
558
+ Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
559
+ is not appropriate for your constraint, you will need to implement the copying yourself.
560
+ """
561
+ new_state = []
562
+ for i in range(batch_size):
563
+ batch_state = []
564
+ for j in range(beam_size):
565
+ if last_backpointer is None:
566
+ # This is the first prediction, so the backpointer is 0
567
+ backpointer = 0
568
+ else:
569
+ backpointer = last_backpointer[i, j].item()
570
+ batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
+ new_state.append(batch_state)
572
+ return new_state
573
+
574
+ def update_state(
575
+ self,
576
+ state: ConstraintStateType,
577
+ last_prediction: torch.Tensor,
578
+ last_backpointer: Optional[torch.Tensor] = None,
579
+ ) -> ConstraintStateType:
580
+ batch_size, beam_size = last_prediction.size()
581
+ new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
+ return self._update_state(new_state, last_prediction)
583
+
584
+ @abstractmethod
585
+ def _update_state(
586
+ self,
587
+ state: ConstraintStateType,
588
+ last_prediction: torch.Tensor,
589
+ ) -> ConstraintStateType:
590
+ raise NotImplementedError
591
+
592
+
593
+ class RepeatedNGramBlockingConstraint(Constraint):
594
+ def __init__(self, ngram_size: int, **kwargs) -> None:
595
+ super().__init__(**kwargs)
596
+ self.ngram_size = ngram_size
597
+
598
+ def init_state(
599
+ self,
600
+ batch_size: int,
601
+ ) -> ConstraintStateType:
602
+ return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
+
604
+ def apply(
605
+ self,
606
+ state: ConstraintStateType,
607
+ class_log_probabilities: torch.Tensor,
608
+ ) -> torch.Tensor:
609
+ for i, batch in enumerate(state):
610
+ for j, beam in enumerate(batch):
611
+ current_prefix = tuple(beam["current_prefix"])
612
+ seen_ngrams = beam["seen_ngrams"]
613
+ try:
614
+ disallowed_indices = seen_ngrams[current_prefix]
615
+ class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
+ class_log_probabilities.dtype
617
+ ).min
618
+ except KeyError:
619
+ # We have not seen this prefix before, so there is no index
620
+ # that needs to be blocked
621
+ pass
622
+ return class_log_probabilities
623
+
624
+ def _update_state(
625
+ self,
626
+ state: ConstraintStateType,
627
+ last_prediction: torch.Tensor,
628
+ ) -> ConstraintStateType:
629
+ for i, batch in enumerate(state):
630
+ for j, beam in enumerate(batch):
631
+ prediction = last_prediction[i, j].item()
632
+ prefix = beam["current_prefix"]
633
+ seen_ngrams = beam["seen_ngrams"]
634
+
635
+ if len(prefix) == self.ngram_size - 1:
636
+ # This is a new ngram that we have to remember
637
+ if tuple(prefix) not in seen_ngrams:
638
+ seen_ngrams[tuple(prefix)] = []
639
+ seen_ngrams[tuple(prefix)].append(prediction)
640
+
641
+ # Create the new prefix, removing the oldest index if the prefix
642
+ # is too long
643
+ prefix.append(prediction)
644
+ if len(prefix) == self.ngram_size:
645
+ prefix.pop(0)
646
+ return state
647
+
648
+
649
+ class BeamSearch:
650
+ """
651
+ Implements the beam search algorithm for decoding the most likely sequences.
652
+
653
+ :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
+
655
+ :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
+ of the predicted sequences.
657
+
658
+ :param beam_size: The width of the beam used.
659
+
660
+ :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
+ If not given, this just defaults to `beam_size`. Setting this parameter
662
+ to a number smaller than `beam_size` may give better results, as it can introduce
663
+ more diversity into the search. See
664
+ [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
+ (https://api.semanticscholar.org/CorpusID:2229477).
666
+
667
+ :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
+ If not specified, `DeterministicSampler` will be used, which just takes the
669
+ `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
+
671
+ Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
+ [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
+
674
+ :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
+ the predicted sequences. This does not include the start or end tokens. If `None`,
676
+ no minimum is enforced.
677
+
678
+ :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
+ The output from this module is what is returned by the `search` method. If not
680
+ specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
+ by the sum of the token log probabilities.
682
+
683
+ :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
+ provided, no constraints will be enforced.
685
+
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ end_index: int,
691
+ *,
692
+ max_steps: int = 50,
693
+ beam_size: int = 10,
694
+ per_node_beam_size: Optional[int] = None,
695
+ sampler: Optional[Sampler] = None,
696
+ min_steps: Optional[int] = None,
697
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
+ constraints: Optional[List[Constraint]] = None,
699
+ ) -> None:
700
+ if not max_steps > 0:
701
+ raise ValueError("max_steps must be positive")
702
+ if not beam_size > 0:
703
+ raise ValueError("beam_size must be positive")
704
+ if per_node_beam_size is not None and not per_node_beam_size > 0:
705
+ raise ValueError("per_node_beam_size must be positive")
706
+ if min_steps is not None:
707
+ if not min_steps >= 0:
708
+ raise ValueError("min_steps must be non-negative")
709
+ if not min_steps <= max_steps:
710
+ raise ValueError("min_steps must be less than or equal to max_steps")
711
+
712
+ self._end_index = end_index
713
+ self.max_steps = max_steps
714
+ self.beam_size = beam_size
715
+ self.per_node_beam_size = per_node_beam_size or beam_size
716
+ self.sampler = sampler or DeterministicSampler()
717
+ self.min_steps = min_steps or 0
718
+ self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
+ self.constraints = constraints or []
720
+
721
+ @staticmethod
722
+ def _reconstruct_sequences(predictions, backpointers):
723
+ # Reconstruct the sequences.
724
+ # shape: [(batch_size, beam_size, 1)]
725
+ reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
+
727
+ if not backpointers:
728
+ return reconstructed_predictions
729
+
730
+ # shape: (batch_size, beam_size)
731
+ cur_backpointers = backpointers[-1]
732
+
733
+ for timestep in range(len(predictions) - 2, 0, -1):
734
+ # shape: (batch_size, beam_size, 1)
735
+ cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
+
737
+ reconstructed_predictions.append(cur_preds)
738
+
739
+ # shape: (batch_size, beam_size)
740
+ cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
+
742
+ # shape: (batch_size, beam_size, 1)
743
+ final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
+
745
+ reconstructed_predictions.append(final_preds)
746
+
747
+ return reconstructed_predictions
748
+
749
+ def search(
750
+ self,
751
+ start_predictions: torch.Tensor,
752
+ start_state: StateType,
753
+ step: StepFunctionType,
754
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
755
+ """
756
+ Given a starting state and a step function, apply beam search to find the
757
+ most likely target sequences.
758
+
759
+ Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
+ has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
+ has shape `(batch_size, beam_size)`.
762
+
763
+ .. note::
764
+ If your step function returns `-inf` for some log probabilities
765
+ (like if you're using a masked log-softmax) then some of the "best"
766
+ sequences returned may also have `-inf` log probability. Specifically
767
+ this happens when the beam size is smaller than the number of actions
768
+ with finite log probability (non-zero probability) returned by the step function.
769
+ Therefore if you're using a mask you may want to check the results from `search`
770
+ and potentially discard sequences with non-finite log probability.
771
+
772
+ :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
+ Usually the initial predictions are just the index of the "start" token
774
+ in the target vocabulary.
775
+
776
+ :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
+ should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
+ number of dimensions.
779
+
780
+ :param step: A function that is responsible for computing the next most likely tokens,
781
+ given the current state and the predictions from the last time step.
782
+ The function should accept two or three arguments:
783
+
784
+ - a tensor of shape `(group_size,)` representing the index of the predicted
785
+ tokens from the last time step,
786
+ - the current state, a `StateType`, and
787
+ - optionally, the timestep, an `int`.
788
+
789
+ The `group_size` will be `batch_size * beam_size`, except in the initial
790
+ step, for which it will just be `batch_size`.
791
+
792
+ The function is expected to return a tuple, where the first element
793
+ is a tensor of shape `(group_size, vocab_size)` containing
794
+ the log probabilities of the tokens for the next step, and the second
795
+ element is the updated state. The tensor in the state should have shape
796
+ `(group_size, *)`, where `*` means any other number of dimensions.
797
+
798
+ """
799
+ step_signature = signature(step)
800
+ if len(step_signature.parameters) < 3:
801
+ # If the step function we're given does not take the time step argument, wrap it
802
+ # in one that does.
803
+ old_step = cast(StepFunctionTypeNoTimestep, step)
804
+
805
+ def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
+ del time_step
807
+ return old_step(last_predictions, state)
808
+
809
+ return self._search(start_predictions, start_state, new_step)
810
+ else:
811
+ return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
+
813
+ def _search(
814
+ self,
815
+ start_predictions: torch.Tensor,
816
+ start_state: StateType,
817
+ step: StepFunctionTypeWithTimestep,
818
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
819
+ batch_size = start_predictions.size()[0]
820
+
821
+ # List of (batch_size, beam_size) tensors. One for each time step. Does not
822
+ # include the start symbols, which are implicit.
823
+ predictions: List[torch.Tensor] = []
824
+
825
+ # List of (batch_size, beam_size) tensors. One for each time step. None for
826
+ # the first. Stores the index n for the parent prediction, i.e.
827
+ # predictions[t-1][i][n], that it came from.
828
+ backpointers: List[torch.Tensor] = []
829
+
830
+ constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
831
+
832
+ # Calculate the first timestep. This is done outside the main loop
833
+ # because we are going from a single decoder input (the output from the
834
+ # encoder) to the top `beam_size` decoder outputs. On the other hand,
835
+ # within the main loop we are going from the `beam_size` elements of the
836
+ # beam to `beam_size`^2 candidates from which we will select the top
837
+ # `beam_size` elements for the next iteration.
838
+ # shape: (batch_size, num_classes)
839
+ start_class_log_probabilities, state = step(start_predictions, start_state, 0)
840
+
841
+ num_classes = start_class_log_probabilities.size()[1]
842
+
843
+ # Make sure `per_node_beam_size` is not larger than `num_classes`.
844
+ if self.per_node_beam_size > num_classes:
845
+ raise ValueError(
846
+ f"Vocab size ({num_classes:d}) too small "
847
+ f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
848
+ f"Please decrease beam_size or per_node_beam_size."
849
+ )
850
+
851
+ sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
852
+
853
+ # Apply all constraints.
854
+ if self.constraints:
855
+ # shape: (batch_size, 1, num_classes)
856
+ expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
857
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
858
+ expanded_start_class_log_probabilities = constraint.apply(
859
+ constraint_state, expanded_start_class_log_probabilities
860
+ )
861
+ start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
862
+
863
+ # Prevent selecting the end symbol if there is any min_steps constraint
864
+ if self.min_steps >= 1:
865
+ start_class_log_probabilities[:, self._end_index] = torch.finfo(
866
+ start_class_log_probabilities.dtype
867
+ ).min
868
+
869
+ # Get the initial predicted classes and their log probabilities.
870
+ # shape: (batch_size, beam_size), (batch_size, beam_size)
871
+ (
872
+ start_top_log_probabilities,
873
+ start_predicted_classes,
874
+ sampler_state,
875
+ ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
876
+
877
+ if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
878
+ warnings.warn(
879
+ "Empty sequences predicted. You may want to increase the beam size or ensure "
880
+ "your step function is working properly.",
881
+ RuntimeWarning,
882
+ )
883
+ return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
884
+
885
+ # The log probabilities for the last time step.
886
+ # shape: (batch_size, beam_size)
887
+ last_log_probabilities = start_top_log_probabilities
888
+
889
+ # shape: [(batch_size, beam_size)]
890
+ predictions.append(start_predicted_classes)
891
+
892
+ # Log probability tensor that mandates that the end token is selected.
893
+ # shape: (batch_size * beam_size, num_classes)
894
+ log_probs_after_end = start_class_log_probabilities.new_full(
895
+ (batch_size * self.beam_size, num_classes),
896
+ torch.finfo(start_class_log_probabilities.dtype).min,
897
+ )
898
+ log_probs_after_end[:, self._end_index] = 0.0
899
+
900
+ # Set the same state for each element in the beam.
901
+ self._update_initial_state(state, batch_size)
902
+
903
+ for i, constraint in enumerate(self.constraints):
904
+ constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
905
+
906
+ for timestep in range(self.max_steps - 1):
907
+ # shape: (batch_size * beam_size,)
908
+ last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
909
+
910
+ # If every predicted token from the last step is `self._end_index`,
911
+ # then we can stop early.
912
+ if (last_predictions == self._end_index).all():
913
+ break
914
+ # Take a step. This gets the predicted log probs of the next classes
915
+ # and updates the state.
916
+ # shape: (batch_size * beam_size, num_classes)
917
+ class_log_probabilities, state = step(last_predictions, state, timestep + 1)
918
+
919
+ # Apply all constraints.
920
+ if self.constraints:
921
+ # shape: (batch_size, beam_size, num_classes)
922
+ reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
923
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
924
+ reshaped_class_log_probabilities = constraint.apply(
925
+ constraint_state, reshaped_class_log_probabilities
926
+ )
927
+ # shape: (batch_size * beam_size, num_classes)
928
+ class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
929
+
930
+ # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
931
+ # of the sequence (because `timestep` is 0-indexed and we generated the first token
932
+ # before the for loop). Here we block the end index if the search is not allowed to
933
+ # terminate on this iteration.
934
+ if timestep + 2 <= self.min_steps:
935
+ class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
936
+
937
+ # shape: (batch_size * beam_size, num_classes)
938
+ last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
939
+ batch_size * self.beam_size, num_classes
940
+ )
941
+
942
+ # Here we are finding any beams where we predicted the end token in
943
+ # the previous timestep and replacing the distribution with a
944
+ # one-hot distribution, forcing the beam to predict the end token
945
+ # this timestep as well.
946
+ # shape: (batch_size * beam_size, num_classes)
947
+ cleaned_log_probabilities = torch.where(
948
+ last_predictions_expanded == self._end_index,
949
+ log_probs_after_end,
950
+ class_log_probabilities,
951
+ )
952
+
953
+ # shape (both): (batch_size * beam_size, per_node_beam_size)
954
+ top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
955
+ cleaned_log_probabilities, self.per_node_beam_size, sampler_state
956
+ )
957
+
958
+ # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
959
+ # so that we can add them to the current log probs for this timestep.
960
+ # This lets us maintain the log probability of each element on the beam.
961
+ # shape: (batch_size * beam_size, per_node_beam_size)
962
+ expanded_last_log_probabilities = (
963
+ last_log_probabilities.unsqueeze(2)
964
+ .expand(batch_size, self.beam_size, self.per_node_beam_size)
965
+ .reshape(batch_size * self.beam_size, self.per_node_beam_size)
966
+ )
967
+
968
+ # shape: (batch_size * beam_size, per_node_beam_size)
969
+ summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
970
+
971
+ # shape: (batch_size, beam_size * per_node_beam_size)
972
+ reshaped_summed = summed_top_log_probabilities.reshape(
973
+ batch_size, self.beam_size * self.per_node_beam_size
974
+ )
975
+
976
+ # shape: (batch_size, beam_size * per_node_beam_size)
977
+ reshaped_predicted_classes = predicted_classes.reshape(
978
+ batch_size, self.beam_size * self.per_node_beam_size
979
+ )
980
+
981
+ # Keep only the top `beam_size` beam indices.
982
+ # shape (both): (batch_size, beam_size)
983
+ (
984
+ restricted_beam_log_probs,
985
+ restricted_beam_indices,
986
+ sampler_state,
987
+ ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
988
+
989
+ # Use the beam indices to extract the corresponding classes.
990
+ # shape: (batch_size, beam_size)
991
+ restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
992
+
993
+ predictions.append(restricted_predicted_classes)
994
+
995
+ # shape: (batch_size, beam_size)
996
+ last_log_probabilities = restricted_beam_log_probs
997
+
998
+ # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
999
+ # indices with a common ancestor are grouped together. Hence
1000
+ # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1001
+ # division as the tensor is a LongTensor.)
1002
+ # shape: (batch_size, beam_size)
1003
+ backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1004
+ backpointers.append(backpointer)
1005
+
1006
+ # Keep only the pieces of the state tensors corresponding to the
1007
+ # ancestors created this iteration.
1008
+ self._update_state(state, backpointer)
1009
+
1010
+ for i, constraint in enumerate(self.constraints):
1011
+ constraint_states[i] = constraint.update_state(
1012
+ constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1013
+ )
1014
+
1015
+ # Warn about "-inf" log probabilities if not using any constraints (negligible
1016
+ # log probabilities are expected when using constraints).
1017
+ if not self.constraints and (
1018
+ not torch.isfinite(last_log_probabilities).all()
1019
+ or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1020
+ ):
1021
+ warnings.warn(
1022
+ "Negligible log probabilities encountered ('-inf' or equivalent). "
1023
+ "Some final sequences may not make sense. "
1024
+ "This can happen when the beam size is larger than the number of valid (non-zero "
1025
+ "probability) transitions that the step function produces.",
1026
+ RuntimeWarning,
1027
+ )
1028
+
1029
+ reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1030
+
1031
+ # shape: (batch_size, beam_size, max_steps)
1032
+ all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1033
+
1034
+ # Calculate the final sequence scores
1035
+ # shape: (batch_size, beam_size)
1036
+ final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1037
+
1038
+ # Sort the sequences based on the final scores so the best scoring
1039
+ # sequence is at index 0
1040
+ sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1041
+ sorted_all_predictions = torch.gather(
1042
+ all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1043
+ )
1044
+
1045
+ return sorted_all_predictions, sorted_final_scores
1046
+
1047
+ def _update_initial_state(self, state: StateType, batch_size: int):
1048
+ """
1049
+ Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1050
+ """
1051
+ for key, state_tensor in state.items():
1052
+ if state_tensor is None:
1053
+ continue
1054
+ # shape: (batch_size * beam_size, *)
1055
+ _, *last_dims = state_tensor.size()
1056
+ state[key] = (
1057
+ state_tensor.unsqueeze(1)
1058
+ .expand(batch_size, self.beam_size, *last_dims)
1059
+ .reshape(batch_size * self.beam_size, *last_dims)
1060
+ )
1061
+
1062
+ def _update_state(self, state: StateType, backpointer: torch.Tensor):
1063
+ batch_size = backpointer.size()[0]
1064
+
1065
+ for key, state_tensor in state.items():
1066
+ if state_tensor is None:
1067
+ continue
1068
+ _, *last_dims = state_tensor.size()
1069
+ # shape: (batch_size, beam_size, *)
1070
+ expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1071
+ batch_size, self.beam_size, *last_dims
1072
+ )
1073
+ # shape: (batch_size * beam_size, *)
1074
+ state[key] = (
1075
+ state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1076
+ .gather(1, expanded_backpointer)
1077
+ .reshape(batch_size * self.beam_size, *last_dims)
1078
+ )
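The comments above describe how each search step recovers the ancestor beam of every selected candidate (integer division of the flat candidate index by `per_node_beam_size`) and then reorders the per-beam state tensors with `gather`. The following standalone sketch illustrates just that reordering step; all sizes and tensors are invented for the example and are not taken from the file above.

import torch

# Illustrative sizes only (not from the file above).
batch_size, beam_size, per_node_beam_size, hidden = 2, 3, 4, 5

# Pretend these are the indices chosen among beam_size * per_node_beam_size candidates.
restricted_beam_indices = torch.randint(0, beam_size * per_node_beam_size, (batch_size, beam_size))

# Integer division recovers which parent beam each chosen candidate came from.
backpointer = torch.divide(restricted_beam_indices, per_node_beam_size, rounding_mode="trunc")

# A fake per-beam state tensor of shape (batch_size * beam_size, hidden).
state_tensor = torch.randn(batch_size * beam_size, hidden)

# Reorder the state so that each row now holds the state of its ancestor beam.
expanded_backpointer = backpointer.view(batch_size, beam_size, 1).expand(batch_size, beam_size, hidden)
reordered = (
    state_tensor.reshape(batch_size, beam_size, hidden)
    .gather(1, expanded_backpointer)
    .reshape(batch_size * beam_size, hidden)
)
print(reordered.shape)  # torch.Size([6, 5])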
model/checkpoint.py ADDED
@@ -0,0 +1,1732 @@
1
+ import gc
2
+ import io
3
+ import logging
4
+ import pickle
5
+ import shutil
6
+ import traceback
7
+ from abc import ABCMeta, abstractmethod
8
+ from collections import defaultdict
9
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
+ from contextlib import contextmanager
11
+ from copy import deepcopy
12
+ from dataclasses import dataclass, field, replace
13
+ from functools import reduce
14
+ from multiprocessing import shared_memory
15
+ from pathlib import Path
16
+ from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed.checkpoint as dist_cp
21
+ import torch.multiprocessing as mp
22
+ from packaging import version
23
+ from torch.distributed import _remote_device
24
+ from torch.distributed._shard._utils import narrow_tensor_by_index
25
+ from torch.distributed._shard.metadata import ShardMetadata
26
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
27
+ from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
28
+ from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
29
+ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
30
+ from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
31
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32
+ from torch.distributed.fsdp import StateDictType
33
+ from torch.distributed.fsdp.api import (
34
+ FullOptimStateDictConfig,
35
+ FullStateDictConfig,
36
+ ShardedOptimStateDictConfig,
37
+ ShardedStateDictConfig,
38
+ )
39
+ from torch.futures import Future
40
+
41
+ try:
42
+ from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
43
+ except ModuleNotFoundError:
44
+ from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
45
+
46
+ from olmo import util
47
+
48
+ from .aliases import PathOrStr
49
+ from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50
+ from .exceptions import OLMoCheckpointError
51
+ from .optim import Optimizer, fix_optim_state_dict
52
+ from .safetensors_util import safetensors_file_to_state_dict
53
+ from .torch_util import (
54
+ barrier,
55
+ gc_cuda,
56
+ get_fs_local_rank,
57
+ get_global_rank,
58
+ get_world_size,
59
+ )
60
+ from .util import (
61
+ _get_s3_client,
62
+ default_thread_count,
63
+ dir_is_empty,
64
+ get_bytes_range,
65
+ get_progress_bar,
66
+ resource_path,
67
+ upload,
68
+ wait_for,
69
+ )
70
+
71
+ __all__ = [
72
+ "save_fsdp_model_and_optim_state",
73
+ "load_fsdp_model_and_optim_state",
74
+ "load_fsdp_optim_state",
75
+ "save_state_dict",
76
+ "load_state_dict",
77
+ "load_model_state",
78
+ "RemoteFileSystemWriter",
79
+ "RemoteFileSystemReader",
80
+ "Checkpointer",
81
+ "FullCheckpointer",
82
+ "TorchNewStyleShardedCheckpointer",
83
+ "TorchLegacyShardedCheckpointer",
84
+ "LocalShardedCheckpointer",
85
+ "build_sharded_checkpointer",
86
+ ]
87
+
88
+
89
+ log = logging.getLogger(__name__)
90
+
91
+ MODEL_AND_OPTIM_FOLDER = "model_and_optim"
92
+
93
+
94
+ def save_fsdp_model_and_optim_state(
95
+ checkpoint_dir: PathOrStr,
96
+ fsdp_model: FSDP,
97
+ optim: Optimizer,
98
+ *,
99
+ upload_to: Optional[str] = None,
100
+ save_overwrite: bool = False,
101
+ ):
102
+ """
103
+ Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
104
+ functions. This should be used during distributed training and should be called by all ranks.
105
+
106
+ :param checkpoint_dir: The directory to save to.
107
+ :param fsdp_model: The FSDP model.
108
+ :param optim: The FSDP model's optimizer.
109
+ :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
110
+ :param save_overwrite: Overwrite existing files.
111
+
112
+ :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
113
+ """
114
+ checkpoint_dir = Path(checkpoint_dir)
115
+ target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
116
+ if save_overwrite:
117
+ if get_fs_local_rank() == 0:
118
+ shutil.rmtree(target_dir, ignore_errors=True)
119
+ elif not dir_is_empty(target_dir):
120
+ raise FileExistsError(target_dir)
121
+ barrier()
122
+ if get_fs_local_rank() == 0:
123
+ target_dir.mkdir(exist_ok=True, parents=True)
124
+ barrier()
125
+ with FSDP.state_dict_type(
126
+ fsdp_model,
127
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
128
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
129
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
130
+ ):
131
+ model_and_optim_state = {
132
+ "model": fsdp_model.state_dict(),
133
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
134
+ }
135
+ dist_cp.save_state_dict(
136
+ model_and_optim_state,
137
+ RemoteFileSystemWriter(
138
+ target_dir,
139
+ upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
140
+ save_overwrite=save_overwrite,
141
+ ),
142
+ )
143
+
144
+
145
+ def load_fsdp_model_and_optim_state(
146
+ checkpoint_dir: PathOrStr,
147
+ fsdp_model: FSDP,
148
+ optim: Optimizer,
149
+ *,
150
+ local_cache: Optional[PathOrStr] = None,
151
+ load_optimizer_state: bool = True,
152
+ ):
153
+ """
154
+ Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
155
+ functions. This should be used during distributed training and should be called by all ranks.
156
+
157
+ :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
158
+ :param fsdp_model: The FSDP model.
159
+ :param optim: The FSDP model's optimizer.
160
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
161
+ remote "directory" but there might be a cached version of the same artifacts.
162
+ :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
163
+
164
+ :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
165
+ """
166
+ load_path = str(checkpoint_dir).rstrip("/")
167
+ local_cache = None if local_cache is None else Path(local_cache)
168
+ with FSDP.state_dict_type(
169
+ fsdp_model,
170
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
171
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
172
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
173
+ ):
174
+ # Load the model state dict in place.
175
+ log.info("Loading model state...")
176
+ model_state = {"model": fsdp_model.state_dict()}
177
+ dist_cp.load_state_dict(
178
+ model_state,
179
+ RemoteFileSystemReader(
180
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
181
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
182
+ ),
183
+ )
184
+ fsdp_model.load_state_dict(model_state["model"])
185
+
186
+ if not load_optimizer_state:
187
+ return
188
+
189
+ # Load optim state dict in place.
190
+ log.info("Loading sharded optimizer state...")
191
+ optim_state = load_sharded_optimizer_state_dict(
192
+ model_state_dict=model_state["model"],
193
+ optimizer_key="optim",
194
+ storage_reader=RemoteFileSystemReader(
195
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
196
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
197
+ ),
198
+ )
199
+ del model_state
200
+ gc_cuda()
201
+ load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
202
+
203
+
204
+ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
205
+ log.info("Flattening sharded optimizer state...")
206
+ # NOTE: Careful! The order of these arguments changed from 2.0 to 2.1... ¯\_(ツ)_/¯
207
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
208
+ flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
209
+ else:
210
+ flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
211
+ del optim_state
212
+ gc.collect()
213
+ log.info("Loading flattened optimizer state...")
214
+ # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
215
+ # which takes up unnecessary GPU memory.
216
+ for state in flattened_osd["state"].values():
217
+ for k in state.keys():
218
+ v = state[k]
219
+ if isinstance(v, torch.Tensor):
220
+ state[k] = v.to(device="cpu")
221
+ gc_cuda()
222
+ optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
223
+
224
+
225
+ def save_state_dict(
226
+ checkpoint_dir: PathOrStr,
227
+ fname: str,
228
+ state_dict: Dict[str, Any],
229
+ *,
230
+ upload_to: Optional[str] = None,
231
+ save_overwrite: bool = False,
232
+ synchronize: bool = True,
233
+ ):
234
+ """
235
+ Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
236
+ This can be used during distributed training or not. If used during distributed training, the ``fname`` should be unique
237
+ for each rank.
238
+
239
+ :param checkpoint_dir: The directory to save to.
240
+ :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
241
+ :param state_dict: The state dict to save.
242
+ :param upload_to: Optional, a remote "directory" to upload the file to.
243
+ :param save_overwrite: Overwrite existing files.
244
+ :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
245
+ this function from a single rank.
246
+
247
+ :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
248
+ """
249
+ checkpoint_dir = Path(checkpoint_dir)
250
+ target_path = checkpoint_dir / fname
251
+ if save_overwrite:
252
+ target_path.unlink(missing_ok=True)
253
+ elif target_path.is_file():
254
+ raise FileExistsError(target_path)
255
+ if synchronize:
256
+ barrier()
257
+ target_path.parent.mkdir(exist_ok=True, parents=True)
258
+ if synchronize:
259
+ barrier()
260
+ torch.save(state_dict, target_path)
261
+ if upload_to is not None:
262
+ upload_target = f"{upload_to.rstrip('/')}/{fname}"
263
+ log.info(f"Uploading {target_path} to {upload_target}...")
264
+ upload(target_path, upload_target, save_overwrite=save_overwrite)
265
+
266
+
267
+ def load_state_dict(
268
+ checkpoint_dir: PathOrStr,
269
+ fname: str,
270
+ *,
271
+ local_cache: Optional[PathOrStr] = None,
272
+ map_location: Optional[str] = None,
273
+ ):
274
+ """
275
+ Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
276
+ This can be used during distributed training or not.
277
+
278
+ :param checkpoint_dir: A local or remote checkpoint directory.
279
+ :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
280
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
281
+ remote "directory" but there might be a cached version of the same artifacts.
282
+
283
+ :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
284
+ """
285
+ if fname.endswith(".pt"):
286
+ # Try safetensors version first.
287
+ try:
288
+ path = resource_path(
289
+ str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
290
+ )
291
+ return safetensors_file_to_state_dict(path, map_location=map_location)
292
+ except FileNotFoundError:
293
+ pass
294
+
295
+ path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
296
+ return torch.load(path, map_location=map_location)
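As a rough illustration of how `save_state_dict()` and `load_state_dict()` pair up, here is a hypothetical single-process usage; the checkpoint directory and contents are made up, `synchronize=False` avoids any distributed barrier, and it assumes the two helpers above are in scope and that `resource_path()` resolves plain local paths (raising `FileNotFoundError` when the file is missing, as documented above).

import torch

# Hypothetical single-process round trip using the helpers defined above.
state = {"step": 100, "weights": torch.zeros(4)}
save_state_dict("checkpoints/example", "train.pt", state, save_overwrite=True, synchronize=False)
restored = load_state_dict("checkpoints/example", "train.pt", map_location="cpu")
print(restored["step"])  # 100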
297
+
298
+
299
+ def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
300
+ """
301
+ Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
302
+ Note that ``model`` should not be wrapped with FSDP.
303
+ """
304
+ state_dict = {"model": model.state_dict()}
305
+ dist_cp.load_state_dict(
306
+ state_dict,
307
+ RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
308
+ no_dist=True,
309
+ )
310
+ model.load_state_dict(state_dict["model"])
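A hypothetical way to use `load_model_state()` to populate an unwrapped model from a checkpoint written by `save_fsdp_model_and_optim_state()`; the constructor and path below are placeholders, and no process group is needed because the reader is used with `no_dist=True`.

import torch

# Placeholder: build the model with the same configuration it was trained with,
# but do NOT wrap it in FSDP.
model = build_model_from_config("config.yaml")   # hypothetical helper, not part of this module
load_model_state("checkpoints/step1000", model)  # fills the parameters in place
model.eval()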
311
+
312
+
313
+ class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
314
+ """
315
+ A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
316
+ directly to a cloud bucket when ``upload_to`` is specified.
317
+ """
318
+
319
+ def __init__(
320
+ self,
321
+ path: PathOrStr,
322
+ single_file_per_rank: bool = True,
323
+ sync_files: bool = True,
324
+ thread_count: Optional[int] = None,
325
+ per_thread_copy_ahead: int = 10_000_000,
326
+ upload_to: Optional[str] = None,
327
+ save_overwrite: bool = False,
328
+ ) -> None:
329
+ if thread_count is not None and thread_count <= 0:
330
+ raise ValueError("thread count must be at least 1")
331
+ super().__init__(
332
+ path,
333
+ single_file_per_rank=single_file_per_rank,
334
+ sync_files=sync_files,
335
+ # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
336
+ # returns because uploading big checkpoint files with multiple threads causes
337
+ # boto3 to fail in weird ways.
338
+ thread_count=thread_count or 1,
339
+ per_thread_copy_ahead=per_thread_copy_ahead,
340
+ )
341
+ self.upload_to = None if upload_to is None else upload_to.rstrip("/")
342
+ self.save_overwrite = save_overwrite
343
+
344
+ def write_data(
345
+ self,
346
+ plan: dist_cp.SavePlan,
347
+ planner: dist_cp.SavePlanner,
348
+ ) -> Future[List[WriteResult]]:
349
+ fut = super().write_data(plan, planner)
350
+ if self.upload_to is not None:
351
+ files_to_upload = set()
352
+ for write_result in fut.wait():
353
+ files_to_upload.add(write_result.storage_data.relative_path)
354
+
355
+ # Create the global S3 client up front to work around a threading issue in boto.
356
+ if self.upload_to.startswith("s3://"):
357
+ _get_s3_client("s3")
358
+ elif self.upload_to.startswith("r2://"):
359
+ _get_s3_client("r2")
360
+
361
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
362
+ futures = []
363
+ for fname in files_to_upload:
364
+ source = self.path / fname
365
+ target = f"{self.upload_to}/{fname}"
366
+ log.info(f"Uploading {source} to {target}...")
367
+ futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
368
+ for f in as_completed(futures):
369
+ try:
370
+ f.result()
371
+ except BaseException:
372
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
373
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
374
+ # sure we're raising a simple error type that can be pickled.
375
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
376
+ return fut
377
+
378
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
379
+ super().finish(metadata, results)
380
+ if self.upload_to is not None:
381
+ source = self.path / ".metadata"
382
+ target = f"{self.upload_to}/.metadata"
383
+ log.info(f"Uploading {source} to {target}...")
384
+ upload(source, target, save_overwrite=self.save_overwrite)
385
+
386
+
387
+ class RemoteFileSystemReader(dist_cp.StorageReader):
388
+ """
389
+ A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
390
+ that can read data directly from cloud storage as well as a local directory.
391
+ """
392
+
393
+ def __init__(
394
+ self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
395
+ ):
396
+ super().__init__()
397
+ if thread_count is not None and thread_count <= 0:
398
+ raise ValueError("thread count must be at least 1")
399
+ self.path = str(path).rstrip("/")
400
+ self.cache = None if local_cache is None else Path(local_cache)
401
+ self.thread_count = thread_count or default_thread_count()
402
+ self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
403
+ self._metadata: Optional[Metadata] = None
404
+
405
+ def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
406
+ if self.cache is not None and (path := self.cache / relative_path).is_file():
407
+ return get_bytes_range(path, offset, length)
408
+ else:
409
+ return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
410
+
411
+ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
412
+ sinfo = self.storage_data[read_item.storage_index]
413
+ content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
414
+ return (read_item, content)
415
+
416
+ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
417
+ # Create the global S3 client up front to work around a threading issue in boto.
418
+ if isinstance(self.path, str):
419
+ if self.path.startswith("s3://"):
420
+ _get_s3_client("s3")
421
+ elif self.path.startswith("r2://"):
422
+ _get_s3_client("r2")
423
+
424
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
425
+ read_item_content_futures = []
426
+ for read_item in plan.items:
427
+ read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
428
+ read_item_content_results = []
429
+ for f in as_completed(read_item_content_futures):
430
+ try:
431
+ read_item_content_results.append(f.result())
432
+ except BaseException:
433
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
434
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
435
+ # sure we're raising a simple error type that can be pickled.
436
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
437
+
438
+ # Modified from `FileSystemReader.read_data()`
439
+ for read_item, content in read_item_content_results:
440
+ bytes = io.BytesIO(content)
441
+ bytes.seek(0)
442
+ if read_item.type == LoadItemType.BYTE_IO:
443
+ planner.load_bytes(read_item, bytes)
444
+ else:
445
+ tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
446
+ tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
447
+ target_tensor = planner.resolve_tensor(read_item).detach()
448
+
449
+ assert (
450
+ target_tensor.size() == tensor.size()
451
+ ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
452
+ target_tensor.copy_(tensor)
453
+ planner.commit_tensor(read_item, target_tensor)
454
+
455
+ fut: Future = Future()
456
+ fut.set_result(None)
457
+ return fut
458
+
459
+ def read_metadata(self) -> Metadata:
460
+ if self._metadata is None:
461
+ with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
462
+ self._metadata = pickle.load(metadata_file)
463
+ return self._metadata
464
+
465
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
466
+ del is_coordinator
467
+ self.storage_data = metadata.storage_data
468
+ assert self.storage_data is not None
469
+
470
+ def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
471
+ return plan
472
+
473
+ def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
474
+ return global_plan
475
+
476
+
477
+ class Checkpointer(metaclass=ABCMeta):
478
+ def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
479
+ self.cfg = cfg
480
+ self.thread_count = thread_count or default_thread_count()
481
+
482
+ @abstractmethod
483
+ def save_checkpoint(
484
+ self,
485
+ dir: PathOrStr,
486
+ fsdp_model: FSDP,
487
+ optim: Optimizer,
488
+ train_state: Dict[str, Any],
489
+ *,
490
+ upload_to: Optional[str] = None,
491
+ ) -> None:
492
+ raise NotImplementedError
493
+
494
+ @abstractmethod
495
+ def restore_checkpoint(
496
+ self,
497
+ load_path: PathOrStr,
498
+ fsdp_model: FSDP,
499
+ optim: Optimizer,
500
+ *,
501
+ local_cache: Optional[PathOrStr] = None,
502
+ load_optimizer_state: bool = True,
503
+ ) -> Dict[str, Any]:
504
+ """
505
+ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
506
+ """
507
+ raise NotImplementedError
508
+
509
+ def unshard_checkpoint(
510
+ self,
511
+ load_path: PathOrStr,
512
+ *,
513
+ local_cache: Optional[PathOrStr] = None,
514
+ load_optimizer_state: bool = True,
515
+ load_trainer_state: bool = True,
516
+ device: Optional[torch.device] = None,
517
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
518
+ """
519
+ Unshard a checkpoint.
520
+
521
+ Note this is not marked abstract because child classes are not required to implement this.
522
+ """
523
+ del load_path, local_cache, load_optimizer_state, load_trainer_state, device
524
+ raise NotImplementedError
525
+
526
+ @contextmanager
527
+ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
528
+ # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
529
+ checkpoint_dir = Path(dir)
530
+ if not dir_is_empty(checkpoint_dir):
531
+ if self.cfg.save_overwrite:
532
+ if get_fs_local_rank() == 0:
533
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
534
+ else:
535
+ raise FileExistsError(checkpoint_dir)
536
+ # No need to mkdir here since we'll directly replace the temporary directory with
537
+ # this directory below.
538
+ barrier()
539
+
540
+ # Prepare temporary directory. We don't have to be as careful here, we can
541
+ # just remove it if it already exists.
542
+ checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
543
+ if get_fs_local_rank() == 0:
544
+ shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
545
+ checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
546
+
547
+ barrier()
548
+
549
+ # Yield temporary directory for `.save_checkpoint()` to use.
550
+ yield checkpoint_dir_tmp
551
+
552
+ barrier()
553
+
554
+ # Finally if all went well replace the temporary directory with the actual
555
+ # checkpoint directory.
556
+ if get_fs_local_rank() == 0:
557
+ # Replace temp directory with target checkpoint directory.
558
+ try:
559
+ checkpoint_dir_tmp.replace(checkpoint_dir)
560
+ except FileNotFoundError:
561
+ # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
562
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
563
+ # file-systems.
564
+ if not checkpoint_dir.exists():
565
+ raise
566
+
567
+ # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
568
+ # replacing the temp directory with the final directory from rank 0 might not be immediately
569
+ # realized in the file systems of the other ranks.
570
+ # So we wait here across all ranks until that final checkpoint directory is visible.
571
+ wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
572
+
573
+ barrier()
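The comments in `_temporary_wd()` above describe a write-to-temporary-directory-then-rename pattern so that a checkpoint directory only ever appears in a complete state. A stripped-down, single-process sketch of the same idea (no barriers, illustrative paths only):

import shutil
from contextlib import contextmanager
from pathlib import Path

@contextmanager
def temporary_dir(final_dir: Path):
    # Write into "<final>-tmp" and only rename it to the final name if no exception occurred.
    tmp = final_dir.with_name(final_dir.name + "-tmp")
    shutil.rmtree(tmp, ignore_errors=True)
    tmp.mkdir(parents=True, exist_ok=True)
    yield tmp
    tmp.replace(final_dir)

with temporary_dir(Path("checkpoints/step1000")) as wd:
    (wd / "train.txt").write_text("example")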
574
+
575
+ def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
576
+ if get_global_rank() == 0:
577
+ log.info("Saving config...")
578
+ self.cfg.save(config_path := Path(dir) / "config.yaml")
579
+ if upload_to is not None:
580
+ upload_target = f"{upload_to}/config.yaml"
581
+ log.info(f"Uploading {config_path} to {upload_target}")
582
+ upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
583
+
584
+
585
+ class FullCheckpointer(Checkpointer):
586
+ """
587
+ A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
588
+ """
589
+
590
+ def save_checkpoint(
591
+ self,
592
+ dir: PathOrStr,
593
+ fsdp_model: FSDP,
594
+ optim: Optimizer,
595
+ trainer_state: Dict[str, Any],
596
+ *,
597
+ upload_to: Optional[str] = None,
598
+ ) -> None:
599
+ with self._temporary_wd(dir) as checkpoint_dir:
600
+ with FSDP.state_dict_type(
601
+ fsdp_model,
602
+ state_dict_type=StateDictType.FULL_STATE_DICT,
603
+ state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
604
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
605
+ ):
606
+ # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
607
+ # First the model state.
608
+ model_state_dict = fsdp_model.state_dict()
609
+ if get_global_rank() == 0:
610
+ log.info("Saving model state...")
611
+ save_state_dict(
612
+ checkpoint_dir,
613
+ "model.pt",
614
+ model_state_dict,
615
+ upload_to=upload_to,
616
+ save_overwrite=self.cfg.save_overwrite,
617
+ synchronize=False,
618
+ )
619
+ del model_state_dict
620
+ barrier()
621
+
622
+ # Then the optimizer state.
623
+ optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
624
+ if get_global_rank() == 0:
625
+ log.info("Saving optim state...")
626
+ save_state_dict(
627
+ checkpoint_dir,
628
+ "optim.pt",
629
+ optim_state_dict,
630
+ upload_to=upload_to,
631
+ save_overwrite=self.cfg.save_overwrite,
632
+ synchronize=False,
633
+ )
634
+ del optim_state_dict
635
+ barrier()
636
+
637
+ # Save trainer state.
638
+ if get_global_rank() == 0:
639
+ log.info("Saving trainer state...")
640
+ save_state_dict(
641
+ checkpoint_dir,
642
+ "train.pt",
643
+ trainer_state,
644
+ upload_to=upload_to,
645
+ save_overwrite=self.cfg.save_overwrite,
646
+ synchronize=False,
647
+ )
648
+ # Save config.
649
+ self._save_config(checkpoint_dir, upload_to=upload_to)
650
+
651
+ def restore_checkpoint(
652
+ self,
653
+ load_path: PathOrStr,
654
+ fsdp_model: FSDP,
655
+ optim: Optimizer,
656
+ *,
657
+ local_cache: Optional[PathOrStr] = None,
658
+ load_optimizer_state: bool = True,
659
+ ) -> Dict[str, Any]:
660
+ with FSDP.state_dict_type(
661
+ fsdp_model,
662
+ state_dict_type=StateDictType.FULL_STATE_DICT,
663
+ state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
664
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
665
+ ):
666
+ with torch.no_grad():
667
+ # fill everything with NaN, so we can check afterwards that every parameter has been restored
668
+ for module_name, module in fsdp_model.named_modules():
669
+ if not isinstance(module, FSDP):
670
+ continue
671
+ for param in module.params:
672
+ param.fill_(torch.nan)
673
+
674
+ # restore params from checkpoint
675
+ state_dict_to_load = load_state_dict(
676
+ load_path, "model.pt", local_cache=local_cache, map_location="cpu"
677
+ )
678
+ (
679
+ state_dict_to_load,
680
+ og_keys_to_new,
681
+ ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
682
+
683
+ for module_name, module in fsdp_model.named_modules():
684
+ if not isinstance(module, FSDP):
685
+ continue
686
+ for param in module.params:
687
+ assert param._is_flat_param
688
+ for fqn, spi in zip(param._fqns, param._shard_param_infos):
689
+ if not spi.in_shard:
690
+ continue
691
+ key = f"{module_name}.{fqn}"
692
+ key = key.replace("_fsdp_wrapped_module.", "")
693
+ key = key.lstrip(".")
694
+ t = state_dict_to_load[key]
695
+ t = t.flatten()
696
+ param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
697
+ t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
698
+ )
699
+
700
+ # make sure that every parameter has been restored
701
+ for module_name, module in fsdp_model.named_modules():
702
+ if not isinstance(module, FSDP):
703
+ continue
704
+ for param in module.params:
705
+ if torch.isnan(param).any():
706
+ raise ValueError(
707
+ f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
708
+ )
709
+
710
+ # Load optimizer state.
711
+ if load_optimizer_state:
712
+ optim_state_dict_to_load = load_state_dict(
713
+ load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
714
+ )
715
+ optim_state_dict_to_load = self._make_optim_state_dict_compatible(
716
+ optim_state_dict_to_load,
717
+ og_keys_to_new,
718
+ )
719
+ load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
720
+ del optim_state_dict_to_load
721
+
722
+ # Load other state.
723
+ try:
724
+ trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
725
+ except FileNotFoundError:
726
+ # for backwards compatibility
727
+ trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
728
+ barrier()
729
+ return trainer_state
730
+
731
+ def _make_optim_state_dict_compatible(
732
+ self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
733
+ ) -> Dict[str, Any]:
734
+ # This state dict comes in two forms: one where the state keys are integers and one where the
735
+ # keys are fully qualified parameter names. The latter case is easier to deal with here so we
736
+ # first transform the integer key form into the FQN key form.
737
+ if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
738
+ id_to_fqn: Dict[int, str] = {}
739
+ for group in optim_state_dict["param_groups"]:
740
+ new_param_names = []
741
+ for fqn, id in zip(group["param_names"], group["params"]):
742
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
743
+ id_to_fqn[id] = fqn
744
+ new_param_names.append(fqn)
745
+ group["param_names"] = new_param_names
746
+ group["params"] = new_param_names
747
+ for id in list(optim_state_dict["state"].keys()):
748
+ optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
749
+ else:
750
+ # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
751
+ for group in optim_state_dict["param_groups"]:
752
+ group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
753
+ group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
754
+ assert group["param_names"] == group["params"]
755
+ for key in list(optim_state_dict["state"].keys()):
756
+ optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
757
+ "state"
758
+ ].pop(key)
759
+
760
+ # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
761
+ # First fix param names in the state.
762
+ for og_key, new_keys in og_keys_to_new.items():
763
+ og_state = optim_state_dict["state"].pop(og_key, None)
764
+ if og_state is None:
765
+ continue
766
+ for i, new_key in enumerate(new_keys):
767
+ if i == len(new_keys) - 1:
768
+ optim_state_dict["state"][new_key] = og_state
769
+ else:
770
+ optim_state_dict["state"][new_key] = deepcopy(og_state)
771
+ # Now fix param names in the param groups.
772
+ for group in optim_state_dict["param_groups"]:
773
+ og_names = group["params"]
774
+ new_names = []
775
+ for og_key in og_names:
776
+ for new_key in og_keys_to_new[og_key]:
777
+ new_names.append(new_key)
778
+ group["params"] = new_names
779
+ group["param_names"] = new_names
780
+
781
+ return optim_state_dict
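The comments in `_make_optim_state_dict_compatible()` above describe re-keying optimizer state from integer parameter IDs to fully qualified parameter names while stripping the `_fsdp_wrapped_module.` prefix. A toy, self-contained version of that first transformation (the names and values are invented):

# Toy optimizer-style state dict with integer parameter IDs.
optim_sd = {
    "param_groups": [{"params": [0, 1], "param_names": ["_fsdp_wrapped_module.w", "_fsdp_wrapped_module.b"]}],
    "state": {0: {"step": 10}, 1: {"step": 10}},
}

# Build the id -> FQN mapping, stripping the FSDP wrapper prefix.
id_to_fqn = {}
for group in optim_sd["param_groups"]:
    fqns = [name.replace("_fsdp_wrapped_module.", "") for name in group["param_names"]]
    for fqn, pid in zip(fqns, group["params"]):
        id_to_fqn[pid] = fqn
    group["param_names"] = fqns
    group["params"] = fqns

# Re-key the per-parameter state by FQN.
for pid in list(optim_sd["state"].keys()):
    optim_sd["state"][id_to_fqn[pid]] = optim_sd["state"].pop(pid)

print(optim_sd["state"].keys())  # dict_keys(['w', 'b'])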
782
+
783
+ def load_checkpoint(
784
+ self,
785
+ load_path: PathOrStr,
786
+ *,
787
+ local_cache: Optional[PathOrStr] = None,
788
+ load_optimizer_state: bool = True,
789
+ device: Optional[torch.device] = None,
790
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
791
+ device = device if device is not None else torch.device("cpu")
792
+ model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
793
+ optim_state = None
794
+ if load_optimizer_state:
795
+ optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
796
+ return model_state, optim_state
797
+
798
+
799
+ class TorchNewStyleShardedCheckpointer(Checkpointer):
800
+ """
801
+ A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
802
+ """
803
+
804
+ def save_checkpoint(
805
+ self,
806
+ dir: PathOrStr,
807
+ fsdp_model: FSDP,
808
+ optim: Optimizer,
809
+ trainer_state: Dict[str, Any],
810
+ *,
811
+ upload_to: Optional[str] = None,
812
+ ) -> None:
813
+ with self._temporary_wd(dir) as checkpoint_dir:
814
+ # Save model and optim state.
815
+ save_fsdp_model_and_optim_state(
816
+ checkpoint_dir,
817
+ fsdp_model,
818
+ optim,
819
+ upload_to=upload_to,
820
+ save_overwrite=self.cfg.save_overwrite,
821
+ )
822
+
823
+ # Save trainer state.
824
+ log.info("Saving trainer state...")
825
+ save_state_dict(
826
+ checkpoint_dir,
827
+ f"train/rank{get_global_rank()}.pt",
828
+ trainer_state,
829
+ upload_to=upload_to,
830
+ save_overwrite=self.cfg.save_overwrite,
831
+ )
832
+
833
+ # Save config.
834
+ self._save_config(checkpoint_dir, upload_to=upload_to)
835
+
836
+ def restore_checkpoint(
837
+ self,
838
+ load_path: PathOrStr,
839
+ fsdp_model: FSDP,
840
+ optim: Optimizer,
841
+ *,
842
+ local_cache: Optional[PathOrStr] = None,
843
+ load_optimizer_state: bool = True,
844
+ ) -> Dict[str, Any]:
845
+ # Load model and optimizer state in place.
846
+ log.info("Loading model and optimizer state...")
847
+ load_fsdp_model_and_optim_state(
848
+ load_path,
849
+ fsdp_model,
850
+ optim,
851
+ local_cache=local_cache,
852
+ load_optimizer_state=load_optimizer_state,
853
+ )
854
+
855
+ # Load trainer state dict.
856
+ log.info("Loading trainer state...")
857
+ try:
858
+ trainer_state = load_state_dict(
859
+ load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
860
+ )
861
+ except FileNotFoundError:
862
+ # Fall back to rank 0 train state.
863
+ # This can happen when we're restoring a checkpoint with a different world size.
864
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
865
+ barrier()
866
+ return trainer_state
867
+
868
+
869
+ class TorchLegacyShardedCheckpointer(Checkpointer):
870
+ """
871
+ A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
872
+ and optim state.
873
+
874
+ The world size must be kept consistent when using this checkpointer.
875
+ """
876
+
877
+ def save_checkpoint(
878
+ self,
879
+ dir: PathOrStr,
880
+ fsdp_model: FSDP,
881
+ optim: Optimizer,
882
+ trainer_state: Dict[str, Any],
883
+ *,
884
+ upload_to: Optional[str] = None,
885
+ ) -> None:
886
+ with self._temporary_wd(dir) as checkpoint_dir:
887
+ with FSDP.state_dict_type(
888
+ fsdp_model,
889
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
890
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
891
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
892
+ ):
893
+ state_dict = {
894
+ "model": fsdp_model.state_dict(),
895
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
896
+ **trainer_state,
897
+ }
898
+ save_state_dict(
899
+ checkpoint_dir,
900
+ f"rank{get_global_rank()}.pt",
901
+ state_dict,
902
+ upload_to=upload_to,
903
+ save_overwrite=self.cfg.save_overwrite,
904
+ )
905
+
906
+ # Save config.
907
+ self._save_config(checkpoint_dir, upload_to=upload_to)
908
+
909
+ def restore_checkpoint(
910
+ self,
911
+ load_path: PathOrStr,
912
+ fsdp_model: FSDP,
913
+ optim: Optimizer,
914
+ *,
915
+ local_cache: Optional[PathOrStr] = None,
916
+ load_optimizer_state: bool = True,
917
+ ) -> Dict[str, Any]:
918
+ with FSDP.state_dict_type(
919
+ fsdp_model,
920
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
921
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
922
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
923
+ ):
924
+ # Deserialize state dict.
925
+ state_dict = load_state_dict(
926
+ load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
927
+ )
928
+
929
+ # Load model and optimizer state.
930
+ log.info("Loading model state...")
931
+ fsdp_model.load_state_dict(state_dict["model"])
932
+ del state_dict["model"]
933
+ if load_optimizer_state:
934
+ log.info("Loading optimizer state...")
935
+ load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
936
+ del state_dict["optim"]
937
+
938
+ barrier()
939
+ return state_dict
940
+
941
+ def unshard_checkpoint(
942
+ self,
943
+ load_path: PathOrStr,
944
+ *,
945
+ local_cache: Optional[PathOrStr] = None,
946
+ load_optimizer_state: bool = True,
947
+ load_trainer_state: bool = True,
948
+ device: Optional[torch.device] = None,
949
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
950
+ assert local_cache is None, "this method currently only supports local files"
951
+ full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
952
+ model_state = full_state_dict.pop("model")
953
+ optim_state = full_state_dict.pop("optim")
954
+ return (
955
+ model_state,
956
+ optim_state if load_optimizer_state else None,
957
+ full_state_dict if load_trainer_state else None,
958
+ )
959
+
960
+ def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
961
+ key = tuple() if key is None else key
962
+ if isinstance(state, (list, tuple, set)):
963
+ for i, sub_state in enumerate(state):
964
+ self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
965
+ elif isinstance(state, dict):
966
+ for name in state.keys():
967
+ self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
968
+ elif isinstance(state, ShardedTensor):
969
+ self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
970
+ return
971
+ else:
972
+ return
973
+
974
+ def _get_shard_placement_and_rank_sizes(
975
+ self, shards_metadata: List[ShardMetadata], world_size: int
976
+ ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
977
+ def shard_size(shard_md):
978
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
979
+
980
+ rank_sizes = [0 for _ in range(world_size)]
981
+ shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
982
+ for shard_md in shards_metadata:
983
+ shard_rank = cast(_remote_device, shard_md.placement).rank()
984
+ assert shard_rank is not None
985
+ if shard_rank >= world_size:
986
+ raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
987
+
988
+ shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
989
+ rank_sizes[shard_rank] += shard_size(shard_md)
990
+
991
+ return shard_placement, rank_sizes
992
+
993
+ def _copy_sharded_tensor_to_shared_mem(
994
+ self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
995
+ ) -> Any:
996
+ shard0_md = sharded_tensor.metadata()
997
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
998
+ shard0_md.shards_metadata, world_size
999
+ )
1000
+
1001
+ rank_size = rank_sizes[rank]
1002
+ assert rank_size >= 0
1003
+ if rank_size == 0:
1004
+ return
1005
+
1006
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1007
+ numpy_type = np.float32
1008
+
1009
+ sharded_memory_name = "-".join(key + (str(rank),))
1010
+
1011
+ shm = shared_memory.SharedMemory(
1012
+ create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1013
+ )
1014
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1015
+
1016
+ for local_shard in sharded_tensor.local_shards():
1017
+ shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1018
+ assert shard_rank == rank
1019
+
1020
+ src = local_shard.tensor.flatten()
1021
+ shard_offset = shard_placement[local_shard.metadata][1]
1022
+
1023
+ np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1024
+
1025
+ shm.close()
1026
+
1027
+ def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1028
+ shard_number = int(shard_filepath.name[4:-3])
1029
+ log.info("Starting unsharding shard number %d to shared memory", shard_number)
1030
+
1031
+ with self._patch_sharded_tensor_load():
1032
+ shard = torch.load(shard_filepath, map_location="cpu")
1033
+ log.debug("Done loading shard number %d", shard_number)
1034
+
1035
+ self._copy_sharded_tensors_to_shared_mem(
1036
+ shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1037
+ )
1038
+ log.info("Done unsharding shard number %d to shared memory", shard_number)
1039
+
1040
+ def _unshard_using_sharded_mem(
1041
+ self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1042
+ ) -> Any:
1043
+ return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1044
+
1045
+ def _unshard_state_using_shared_mem(
1046
+ self, state: Any, world_size: int, device: torch.device, key: Tuple
1047
+ ) -> Any:
1048
+ if isinstance(state, (list, tuple, set)):
1049
+ return state.__class__(
1050
+ self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1051
+ for i, sub_state in enumerate(state)
1052
+ )
1053
+ elif isinstance(state, dict):
1054
+ return {
1055
+ name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1056
+ for name in state.keys()
1057
+ }
1058
+ elif isinstance(state, ShardedTensor):
1059
+ return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1060
+ elif isinstance(state, torch.Tensor):
1061
+ return state.to(device=device)
1062
+ else:
1063
+ return state
1064
+
1065
+ def _unshard_tensor_using_shared_mem(
1066
+ self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1067
+ ) -> torch.Tensor:
1068
+ shard0_md = sharded_tensor.metadata()
1069
+
1070
+ def shard_size(shard_md):
1071
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1072
+
1073
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1074
+ shard0_md.shards_metadata, world_size
1075
+ )
1076
+
1077
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1078
+ numpy_type = np.float32
1079
+
1080
+ out = torch.empty(
1081
+ *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1082
+ )
1083
+ dims = len(sharded_tensor.metadata().size)
1084
+ for shard_md, (rank, rank_offset) in shard_placement.items():
1085
+ if rank >= world_size:
1086
+ raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1087
+
1088
+ sharded_memory_name = "-".join(key + (str(rank),))
1089
+ shm = shared_memory.SharedMemory(name=sharded_memory_name)
1090
+
1091
+ rank_size = rank_sizes[rank]
1092
+ assert rank_size >= 0
1093
+ if rank_size == 0:
1094
+ continue
1095
+
1096
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1097
+
1098
+ tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1099
+ tensor = tensor.view(shard_md.shard_sizes)
1100
+
1101
+ out_narrow_view = out
1102
+ for dim in range(dims):
1103
+ out_narrow_view = out_narrow_view.narrow(
1104
+ dim,
1105
+ shard_md.shard_offsets[dim],
1106
+ shard_md.shard_sizes[dim],
1107
+ )
1108
+
1109
+ out_narrow_view.copy_(tensor)
1110
+
1111
+ shm.close()
1112
+ shm.unlink()
1113
+
1114
+ return out
1115
+
1116
+ @contextmanager
1117
+ def _patch_sharded_tensor_load(self):
1118
+ """
1119
+ Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1120
+ """
1121
+
1122
+ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1123
+ ret = func(*args)
1124
+ if type(ret) is not new_type:
1125
+ ret = ret.as_subclass(new_type)
1126
+
1127
+ # Shortcut the construction of ShardedTensor
1128
+ # This is in the top 5 of my worst hacks.
1129
+ if isinstance(ret, ShardedTensor):
1130
+ ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1131
+ return ret
1132
+
1133
+ # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1134
+ # Tensor does define __setstate__ even though it doesn't define
1135
+ # __getstate__. So only use __setstate__ if it is NOT the one defined
1136
+ # on Tensor
1137
+ if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1138
+ ret.__setstate__(state)
1139
+ else:
1140
+ ret = torch._utils._set_obj_state(ret, state)
1141
+ return ret
1142
+
1143
+ original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1144
+ try:
1145
+ torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1146
+ yield
1147
+ finally:
1148
+ torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1149
+
1150
+ def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1151
+ """
1152
+ The current unsharding implementation consists of:
1153
+
1154
+ 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1155
+ 2. Loading 1 shard on the main process as a base unsharded object.
1156
+ 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1157
+
1158
+ This implementation replaced a prior implementation that instead loaded
1159
+ all shards using threads, because that implementation turned out to
1160
+ be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1161
+ The current implementation is slower than the old one in many scenarios,
1162
+ but is significantly faster in the above-mentioned case (e.g. 30 minutes)
1163
+ if there are enough CPUs.
1164
+ """
1165
+
1166
+ input_dir = Path(input_dir)
1167
+ skip_keys = skip_keys or set()
1168
+
1169
+ shard_filepaths = list(input_dir.glob("rank*.pt"))
1170
+ world_size = len(shard_filepaths)
1171
+ if world_size == 0:
1172
+ raise RuntimeError("No shards found for unsharding")
1173
+
1174
+ log.info("Number of shards: %d", world_size)
1175
+ shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1176
+ min_ram_required_estimate_gb = shard_size_gb * world_size
1177
+ log.info(
1178
+ "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1179
+ )
1180
+
1181
+ log.info("Copying sharded tensors to shared memory using multiple processes")
1182
+ # Copy sharded data to shared memory using multiple processes, so this process can load
1183
+ # from memory rather than disk. We spawn a new process instead of forking since shared memory
1184
+ # appears to get deleted when forked processes end for some reason.
1185
+ executor = ProcessPoolExecutor(
1186
+ mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1187
+ )
1188
+ futures = []
1189
+ for shard_filepath in shard_filepaths:
1190
+ shard_rank = int(shard_filepath.name[4:-3])
1191
+
1192
+ if shard_rank >= world_size:
1193
+ raise RuntimeError(
1194
+ f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1195
+ )
1196
+
1197
+ futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1198
+
1199
+ for f in as_completed(futures):
1200
+ f.result()
1201
+ executor.shutdown()
1202
+
1203
+ log.info("Loading a shard on the main process to be unsharded state")
1204
+ with self._patch_sharded_tensor_load():
1205
+ state = torch.load(shard_filepaths[0], map_location="cpu")
1206
+
1207
+ for key in skip_keys:
1208
+ if key in state:
1209
+ del state[key]
1210
+
1211
+ log.info("Unsharding from %d shards ...", world_size)
1212
+ return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
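The `_unshard()` docstring above relies on `multiprocessing.shared_memory` to hand flattened shard data between processes by name. A minimal standalone demonstration of that mechanism, independent of FSDP (the block name and sizes are arbitrary):

import numpy as np
from multiprocessing import shared_memory

# Producer: create a named block of shared memory and fill it through a NumPy view.
shm = shared_memory.SharedMemory(create=True, size=8 * np.dtype(np.float32).itemsize, name="demo-shard-0")
src = np.ndarray((8,), dtype=np.float32, buffer=shm.buf)
src[:] = np.arange(8, dtype=np.float32)

# Consumer (could be another process): attach by name and read the same buffer.
shm2 = shared_memory.SharedMemory(name="demo-shard-0")
dst = np.ndarray((8,), dtype=np.float32, buffer=shm2.buf)
print(dst.sum())  # 28.0

# Clean up: close both handles, unlink once.
shm2.close()
shm.close()
shm.unlink()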
1213
+
1214
+
1215
+ @dataclass
1216
+ class _LocalShardedCheckpointerMetadata(BaseConfig):
1217
+ world_size: int = field(default_factory=get_world_size)
1218
+
1219
+
1220
+ @dataclass
1221
+ class _FlatParamShard:
1222
+ full_shape: torch.Size
1223
+ shard_offsets: Tuple[int, int]
1224
+ shard_data: Optional[torch.Tensor]
1225
+
1226
+ def copy_into(self, full_tensor: torch.Tensor) -> None:
1227
+ assert self.shard_data is not None
1228
+ full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1229
+ assert self.shard_data.shape == full_tensor_shard_view.shape
1230
+ full_tensor_shard_view.copy_(self.shard_data)
1231
+
1232
+
1233
+ class LocalShardedCheckpointer(Checkpointer):
1234
+ """
1235
+ A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1236
+ The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1237
+
1238
+ The world size must be kept consistent when using this checkpointer. However, you can easily
1239
+ reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1240
+ using :meth:`unshard_checkpoint()` (no distributed initialization required).
1241
+ """
1242
+
1243
+ # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
+ _FLAT_PARAM_METADATA_TO_SAVE = (
1245
+ "_fqns",
1246
+ "_shard_param_offsets",
1247
+ "_shard_indices",
1248
+ "_numels",
1249
+ "_numels_with_padding",
1250
+ "_shapes",
1251
+ "_shard_numel_padded",
1252
+ "_shard_param_infos",
1253
+ )
1254
+
1255
+ def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
+ """
1257
+ Returns a list of FSDP modules with their FQN.
1258
+ """
1259
+ modules = []
1260
+ for name, module in fsdp_model.named_modules():
1261
+ if isinstance(module, FSDP):
1262
+ modules.append((name, module))
1263
+ return modules
1264
+
1265
+ def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
+ from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
+
1268
+ # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
+ # an FSDP state dict through the built-in methods.
1270
+ if torch.cuda.is_available():
1271
+ torch.cuda.synchronize()
1272
+ _lazy_init(fsdp_model, fsdp_model)
1273
+
1274
+ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
+ return fsdp_model._handles # type: ignore
1277
+ elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
+ # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
+ if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
+ return [fsdp_model._handle] # type: ignore
1281
+ else:
1282
+ return []
1283
+ else:
1284
+ # Need to verify FSDP internals with newer versions.
1285
+ raise NotImplementedError
1286
+
1287
+ @torch.no_grad()
1288
+ def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
+ self._prepare_fsdp_model(fsdp_model)
1290
+ module_data = []
1291
+ for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
+ handle_data = []
1293
+ for handle in self._fsdp_handles(fsdp_module):
1294
+ data: Dict[str, Any] = {}
1295
+ # This is a `FlatParameter` instance.
1296
+ # See `torch.distributed.fsdp.flat_param` for the API.
1297
+ flat_param = handle.flat_param
1298
+ data["flat_param.data"] = flat_param.detach()
1299
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
+ if hasattr(flat_param, key):
1301
+ data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
+ handle_data.append(data)
1303
+ module_data.append({"handles": handle_data, "name": module_fqn})
1304
+ return {"modules": module_data}
1305
+
1306
+ @torch.no_grad()
1307
+ def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
+ """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
+ self._prepare_fsdp_model(fsdp_model)
1310
+ fsdp_modules = self._fsdp_modules(fsdp_model)
1311
+ assert len(model_state["modules"]) == len(fsdp_modules)
1312
+ for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
+ handles = self._fsdp_handles(fsdp_module)
1314
+ assert len(handles) == len(module_data["handles"])
1315
+ for handle, data in zip(handles, module_data["handles"]):
1316
+ flat_param = handle.flat_param
1317
+ # Make sure metadata matches.
1318
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
+ if hasattr(flat_param, key):
1320
+ assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
+ # Load the flat sharded data.
1322
+ flat_param.copy_(data["flat_param.data"])
1323
+
1324
+ def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
+ if get_fs_local_rank() == 0:
1326
+ log.info("Saving metadata...")
1327
+ metadata = _LocalShardedCheckpointerMetadata()
1328
+ metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
+ if upload_to is not None and get_global_rank() == 0:
1330
+ upload_target = f"{upload_to}/metadata.yaml"
1331
+ log.info(f"Uploading {metadata_path} to {upload_target}")
1332
+ upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
+
1334
+ def _load_metadata(
1335
+ self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
+ ) -> _LocalShardedCheckpointerMetadata:
1337
+ metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
+ return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
+
1340
+ def save_checkpoint(
1341
+ self,
1342
+ dir: PathOrStr,
1343
+ fsdp_model: FSDP,
1344
+ optim: Optimizer,
1345
+ trainer_state: Dict[str, Any],
1346
+ *,
1347
+ upload_to: Optional[str] = None,
1348
+ ) -> None:
1349
+ with self._temporary_wd(dir) as checkpoint_dir:
1350
+ # Gather local FSDP flat params data to save.
1351
+ # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1352
+ # of each original parameter so we can validate that the sharding is the same when loading
1353
+ # one of these checkpoints.
1354
+ log.info("Saving local FSDP flat params data...")
1355
+ save_state_dict(
1356
+ checkpoint_dir,
1357
+ f"model/rank{get_global_rank()}.pt",
1358
+ self._get_flat_param_state_to_save(fsdp_model),
1359
+ upload_to=upload_to,
1360
+ save_overwrite=self.cfg.save_overwrite,
1361
+ )
1362
+
1363
+ # Save optimizer state.
1364
+ log.info("Saving local optimizer state...")
1365
+ save_state_dict(
1366
+ checkpoint_dir,
1367
+ f"optim/rank{get_global_rank()}.pt",
1368
+ optim.state_dict(),
1369
+ upload_to=upload_to,
1370
+ save_overwrite=self.cfg.save_overwrite,
1371
+ )
1372
+
1373
+ # Save trainer state.
1374
+ log.info("Saving trainer state...")
1375
+ save_state_dict(
1376
+ checkpoint_dir,
1377
+ f"train/rank{get_global_rank()}.pt",
1378
+ trainer_state,
1379
+ upload_to=upload_to,
1380
+ save_overwrite=self.cfg.save_overwrite,
1381
+ )
1382
+
1383
+ # Save metadata.
1384
+ self._save_metadata(checkpoint_dir, upload_to=upload_to)
1385
+
1386
+ # Save config. We do this last b/c the presence of a config in a remote checkpoint
1387
+ # "directory" indicates that the folder is valid, as opposed to a partially
1388
+ # uploaded checkpoint directory that failed before completing.
1389
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1390
+
1391
+ def restore_checkpoint(
1392
+ self,
1393
+ load_path: PathOrStr,
1394
+ fsdp_model: FSDP,
1395
+ optim: Optimizer,
1396
+ *,
1397
+ local_cache: Optional[PathOrStr] = None,
1398
+ load_optimizer_state: bool = True,
1399
+ ) -> Dict[str, Any]:
1400
+ # Load metadata and make sure checkpoint is compatible.
1401
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1402
+ assert metadata.world_size == get_world_size()
1403
+
1404
+ # Load local FSDP flat param data.
1405
+ log.info("Loading local FSDP flat params data...")
1406
+ model_state = load_state_dict(
1407
+ load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1408
+ )
1409
+ self._load_flat_param_state(fsdp_model, model_state)
1410
+ del model_state
1411
+
1412
+ # Load local optim state.
1413
+ if load_optimizer_state:
1414
+ log.info("Loading local optimizer state...")
1415
+ optim_state = load_state_dict(
1416
+ load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1417
+ )
1418
+ # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1419
+ # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1420
+ # state since torch sees the state is non-empty for some params which would normally be empty,
1421
+ # and then assumes it should have all of the other state tensors for that param, which it doesn't.
1422
+ # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1423
+ # Not the end of the world but there's probably a better way around this without resetting
1424
+ # the metric.
1425
+ for param_id in list(optim_state["state"].keys()):
1426
+ state = optim_state["state"][param_id]
1427
+ if "grad_norm_exp_avg" in state:
1428
+ del state["grad_norm_exp_avg"]
1429
+ if len(state) == 0:
1430
+ del optim_state["state"][param_id]
1431
+ optim.load_state_dict(optim_state)
1432
+ del optim_state
1433
+
1434
+ # Load local trainer state.
1435
+ log.info("Loading local trainer state...")
1436
+ trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1437
+ barrier()
1438
+ return trainer_state
1439
+
1440
+ def _iter_flat_param_shards(
1441
+ self, model_state: Dict[str, Any]
1442
+ ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1443
+ for module_data in model_state["modules"]:
1444
+ module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1445
+ for handle in module_data["handles"]:
1446
+ flat_data = handle["flat_param.data"]
1447
+ if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1448
+ # If there's padding in the flat param it should be on the right.
1449
+ assert (flat_data[-num_padding:] == 0).all()
1450
+ # NOTE: this changes depending on the torch version, but we don't do a version
1451
+ # check since we might be trying to unshard an old checkpoint that was stored
1452
+ # with a different torch version than we're currently running with.
1453
+ if "flat_param._shard_indices" in handle:
1454
+ # torch <=2.0.1
1455
+ param_start = handle["flat_param._shard_indices"][0]
1456
+ current_flat_index = 0
1457
+ for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1458
+ handle["flat_param._fqns"][param_start:],
1459
+ handle["flat_param._shapes"][param_start:],
1460
+ handle["flat_param._shard_param_offsets"],
1461
+ ):
1462
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1463
+ numel_shard = offset_end - offset_start + 1
1464
+ flat_param_shard = _FlatParamShard(
1465
+ full_shape=full_shape,
1466
+ shard_offsets=(offset_start, offset_end),
1467
+ shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1468
+ )
1469
+ current_flat_index += numel_shard
1470
+ yield root_fqn, flat_param_shard
1471
+ else:
1472
+ # torch >=2.1.0
1473
+ for relative_fqn, full_shape, shard_param_info in zip(
1474
+ handle["flat_param._fqns"],
1475
+ handle["flat_param._shapes"],
1476
+ handle["flat_param._shard_param_infos"],
1477
+ ):
1478
+ if not shard_param_info.in_shard:
1479
+ continue
1480
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1481
+ flat_param_shard = _FlatParamShard(
1482
+ full_shape=full_shape,
1483
+ shard_offsets=(
1484
+ shard_param_info.intra_param_start_idx,
1485
+ shard_param_info.intra_param_end_idx,
1486
+ ),
1487
+ shard_data=flat_data[
1488
+ shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1489
+ + shard_param_info.numel_in_shard
1490
+ ],
1491
+ )
1492
+ yield root_fqn, flat_param_shard
1493
+
1494
+ def unshard_checkpoint(
1495
+ self,
1496
+ load_path: PathOrStr,
1497
+ *,
1498
+ local_cache: Optional[PathOrStr] = None,
1499
+ load_optimizer_state: bool = True,
1500
+ load_trainer_state: bool = True,
1501
+ device: Optional[torch.device] = None,
1502
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1503
+ device = device or torch.device("cpu")
1504
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1505
+
1506
+ # Gather paths to the model state files, potentially downloading them.
1507
+ log.info("Gathering model state dicts...")
1508
+ model_state_paths = self._gather_state_dict_paths(
1509
+ load_path, "model", metadata.world_size, local_cache=local_cache
1510
+ )
1511
+
1512
+ # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1513
+ log.info("Materializing full parameters...")
1514
+ full_model_state: Dict[str, torch.Tensor] = {}
1515
+ # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1516
+ # the full optimizer state below without having to reload the model state dicts.
1517
+ flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1518
+ for rank, path in enumerate(model_state_paths):
1519
+ log.info(f"Loading shards from rank {rank}...")
1520
+ model_state = torch.load(path, map_location="cpu")
1521
+ for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1522
+ if root_fqn not in full_model_state:
1523
+ log.info(
1524
+ f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1525
+ )
1526
+ assert flat_param_shard.shard_data is not None
1527
+ full_model_state[root_fqn] = torch.empty(
1528
+ flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1529
+ )
1530
+ # Fill with NaNs so we can validate that the whole parameter has been populated
1531
+ # afterwards.
1532
+ full_model_state[root_fqn].fill_(torch.nan)
1533
+ # Copy over the local shard to the relevant part of the full parameter.
1534
+ full_param = full_model_state[root_fqn]
1535
+ log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1536
+ flat_param_shard.copy_into(full_param)
1537
+ flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1538
+
1539
+ log.info("Validating full parameters...")
1540
+ for key, tensor in full_model_state.items():
1541
+ if torch.isnan(tensor).any():
1542
+ raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1543
+
1544
+ trainer_state: Optional[Dict[str, Any]] = None
1545
+ if load_trainer_state:
1546
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1547
+
1548
+ if not load_optimizer_state:
1549
+ return full_model_state, None, trainer_state
1550
+
1551
+ log.info("Gathering optim state dicts...")
1552
+ optim_state_paths = self._gather_state_dict_paths(
1553
+ load_path, "optim", metadata.world_size, local_cache=local_cache
1554
+ )
1555
+
1556
+ log.info("Materializing full optim state...")
1557
+ full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1558
+ fqn_to_id: Dict[str, int] = {}
1559
+ id_to_fqn: Dict[int, str] = {}
1560
+ for rank, path in enumerate(optim_state_paths):
1561
+ log.info(f"Loading sharded optim state from rank {rank}...")
1562
+ optim_state = torch.load(path, map_location="cpu")
1563
+
1564
+ # Initialize param groups.
1565
+ # We assume parameter groups are the same across all ranks.
1566
+ # The only thing that differs across ranks is the state for each local sharded param.
1567
+ if "param_groups" not in full_optim_state:
1568
+ full_optim_state["param_groups"] = optim_state["param_groups"]
1569
+ else:
1570
+ assert full_optim_state["param_groups"] == optim_state["param_groups"]
1571
+
1572
+ # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1573
+ if not fqn_to_id or not id_to_fqn:
1574
+ for group in full_optim_state["param_groups"]:
1575
+ for fqn, id in zip(group["param_names"], group["params"]):
1576
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
1577
+ fqn_to_id[fqn] = id
1578
+ id_to_fqn[id] = fqn
1579
+
1580
+ # Iterate over local shard state and copy into the full state.
1581
+ for id, shard_state in optim_state["state"].items():
1582
+ fqn = id_to_fqn[id]
1583
+ flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1584
+ full_state = full_optim_state["state"][id]
1585
+ for key, shard_value in shard_state.items():
1586
+ assert isinstance(shard_value, torch.Tensor)
1587
+ if shard_value.shape == torch.Size([]):
1588
+ # Add singleton tensors directly to full state. These should be the same across
1589
+ # all ranks.
1590
+ assert key in ("step", "grad_norm_exp_avg") # sanity check
1591
+ if key not in full_state:
1592
+ full_state[key] = shard_value.to(device)
1593
+ else:
1594
+ assert full_state[key] == shard_value
1595
+ else:
1596
+ # Otherwise we have a sharded param state.
1597
+ # If the corresponding full param state hasn't been materialized yet, do so now.
1598
+ assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1599
+ if key not in full_state:
1600
+ log.info(
1601
+ f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1602
+ )
1603
+ full_state[key] = torch.empty(
1604
+ flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1605
+ )
1606
+ full_state_value = full_state[key]
1607
+
1608
+ # Copy over the local shard state to the relevant part of the full parameter state.
1609
+ log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1610
+ replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1611
+
1612
+ # Lastly, clean up the parameter names in param groups.
1613
+ for group in full_optim_state["param_groups"]:
1614
+ group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1615
+
1616
+ return full_model_state, full_optim_state, trainer_state
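Since ``unshard_checkpoint()`` needs no distributed initialization, a consolidated state dict can be produced offline from a single process. A hedged usage sketch; the checkpoint directory, config path, and output filename are illustrative only:

import torch

cfg = TrainConfig.load("config.yaml", validate_paths=False)  # TrainConfig is defined in model/config.py below
checkpointer = LocalShardedCheckpointer(cfg)
model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
    "checkpoints/step1000",  # directory holding model/rank*.pt, optim/rank*.pt, train/rank*.pt, metadata.yaml
    load_optimizer_state=False,
    load_trainer_state=False,
)
torch.save(model_state, "model_unsharded.pt")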
1617
+
1618
+ def _get_state_dict_path(
1619
+ self,
1620
+ load_path: PathOrStr,
1621
+ state_dict_type: str,
1622
+ rank: int,
1623
+ *,
1624
+ local_cache: Optional[PathOrStr] = None,
1625
+ progress=None,
1626
+ ) -> Tuple[int, Path]:
1627
+ fname = f"{state_dict_type}/rank{rank}.pt"
1628
+ return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1629
+
1630
+ def _gather_state_dict_paths(
1631
+ self,
1632
+ load_path: PathOrStr,
1633
+ state_dict_type: str,
1634
+ world_size: int,
1635
+ *,
1636
+ local_cache: Optional[PathOrStr] = None,
1637
+ ) -> List[Path]:
1638
+ progress = get_progress_bar()
1639
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1640
+ futures = []
1641
+ for rank in range(world_size):
1642
+ future = executor.submit(
1643
+ self._get_state_dict_path,
1644
+ load_path,
1645
+ state_dict_type,
1646
+ rank,
1647
+ local_cache=local_cache,
1648
+ progress=progress,
1649
+ )
1650
+ futures.append(future)
1651
+
1652
+ results: Dict[int, Path] = {}
1653
+ for future in as_completed(futures):
1654
+ rank, path = future.result()
1655
+ results[rank] = path
1656
+
1657
+ return [results[rank] for rank in range(world_size)]
1658
+
1659
+
1660
+ class OlmoCoreCheckpointer(Checkpointer):
1661
+ def save_checkpoint(
1662
+ self,
1663
+ dir: PathOrStr,
1664
+ fsdp_model: FSDP,
1665
+ optim: Optimizer,
1666
+ trainer_state: Dict[str, Any],
1667
+ *,
1668
+ upload_to: Optional[str] = None,
1669
+ ) -> None:
1670
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1671
+ save_model_and_optim_state,
1672
+ )
1673
+
1674
+ with self._temporary_wd(dir) as checkpoint_dir:
1675
+ log.info("Saving model and optim state...")
1676
+ save_model_and_optim_state(checkpoint_dir, fsdp_model, optim, save_overwrite=self.cfg.save_overwrite)
1677
+ if upload_to is not None and get_fs_local_rank() == 0:
1678
+ for path in Path(checkpoint_dir).glob("**/*"):
1679
+ if not path.is_file():
1680
+ continue
1681
+ upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
1682
+ log.info(f"Uploading {path} to {upload_target}...")
1683
+ upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)
1684
+
1685
+ log.info("Saving trainer state...")
1686
+ save_state_dict(
1687
+ checkpoint_dir,
1688
+ f"train/rank{get_global_rank()}.pt",
1689
+ trainer_state,
1690
+ save_overwrite=self.cfg.save_overwrite,
1691
+ upload_to=upload_to,
1692
+ )
1693
+
1694
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1695
+
1696
+ def restore_checkpoint(
1697
+ self,
1698
+ load_path: PathOrStr,
1699
+ fsdp_model: FSDP,
1700
+ optim: Optimizer,
1701
+ *,
1702
+ local_cache: Optional[PathOrStr] = None,
1703
+ load_optimizer_state: bool = True,
1704
+ ) -> Dict[str, Any]:
1705
+ from olmo_core.distributed.checkpoint import ( # type: ignore
1706
+ load_model_and_optim_state,
1707
+ )
1708
+
1709
+ log.info("Loading model and optim state...")
1710
+ load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)
1711
+
1712
+ log.info("Loading trainer state...")
1713
+ trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1714
+
1715
+ barrier()
1716
+ return trainer_state
1717
+
1718
+
1719
+ def build_sharded_checkpointer(
1720
+ cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
1721
+ ) -> Checkpointer:
1722
+ name = name or cfg.sharded_checkpointer
1723
+ if name == ShardedCheckpointerType.torch_new:
1724
+ return TorchNewStyleShardedCheckpointer(cfg)
1725
+ elif name == ShardedCheckpointerType.torch_legacy:
1726
+ return TorchLegacyShardedCheckpointer(cfg)
1727
+ elif name == ShardedCheckpointerType.local:
1728
+ return LocalShardedCheckpointer(cfg)
1729
+ elif name == ShardedCheckpointerType.olmo_core:
1730
+ return OlmoCoreCheckpointer(cfg)
1731
+ else:
1732
+ raise NotImplementedError(name)
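A short usage sketch of the factory; it assumes ``cfg`` is an already-loaded ``TrainConfig``:

checkpointer = build_sharded_checkpointer(cfg)  # type chosen by cfg.sharded_checkpointer
local_ckpt = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)  # explicit override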
model/config.py ADDED
@@ -0,0 +1,1113 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from glob import glob
5
+ from pathlib import Path
6
+ from typing import (
7
+ Any,
8
+ Dict,
9
+ Iterable,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ cast,
17
+ )
18
+
19
+ import torch
20
+ from omegaconf import DictConfig, ListConfig
21
+ from omegaconf import OmegaConf as om
22
+ from omegaconf.errors import OmegaConfBaseException
23
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
24
+
25
+ from .aliases import PathOrStr
26
+ from .exceptions import OLMoConfigurationError
27
+ from .util import StrEnum
28
+
29
+ __all__ = [
30
+ "ActivationType",
31
+ "ActivationCheckpointingStrategy",
32
+ "BlockType",
33
+ "LayerNormType",
34
+ "InitFnType",
35
+ "ModelConfig",
36
+ "OptimizerType",
37
+ "OptimizerConfig",
38
+ "SchedulerType",
39
+ "SchedulerConfig",
40
+ "DataConfig",
41
+ "EvaluatorConfig",
42
+ "TokenizerConfig",
43
+ "TrainConfig",
44
+ "PaddingDirection",
45
+ "TruncationDirection",
46
+ "SpeedMonitorConfig",
47
+ "WandbConfig",
48
+ "CompilerConfig",
49
+ "WandbConfig",
50
+ "FSDPPrecision",
51
+ "FSDPWrapStrategy",
52
+ "FSDPConfig",
53
+ "CheckpointType",
54
+ ]
55
+
56
+ C = TypeVar("C", bound="BaseConfig")
57
+ D = TypeVar("D", bound="DictConfig|ListConfig")
58
+
59
+
60
+ class BaseConfig:
61
+ @classmethod
62
+ def _register_resolvers(cls, validate_paths: bool = True):
63
+ # Expands path globs into a list.
64
+ def path_glob(*paths) -> List[str]:
65
+ out = []
66
+ for path in paths:
67
+ matches = sorted(glob(path))
68
+ if not matches and validate_paths:
69
+ raise FileNotFoundError(f"{path} does not match any files or dirs")
70
+ out.extend(matches)
71
+ return out
72
+
73
+ # Chooses the first path in the arguments that exists.
74
+ def path_choose(*paths) -> str:
75
+ from .util import is_url
76
+
77
+ for path in paths:
78
+ if is_url(path) or Path(path).exists():
79
+ return path
80
+ if validate_paths:
81
+ raise FileNotFoundError(", ".join(paths))
82
+ else:
83
+ return ""
84
+
85
+ # Finds the latest checkpoint in a folder.
86
+ def path_last_checkpoint(path) -> str:
87
+ from .util import find_latest_checkpoint
88
+
89
+ latest_checkpoint = find_latest_checkpoint(path)
90
+ if latest_checkpoint is None:
91
+ if validate_paths:
92
+ raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
93
+ else:
94
+ return ""
95
+ else:
96
+ return str(latest_checkpoint)
97
+
98
+ om.register_new_resolver("path.glob", path_glob, replace=True)
99
+ om.register_new_resolver("path.choose", path_choose, replace=True)
100
+ om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
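These resolvers are invoked from YAML through OmegaConf's ``${resolver:args}`` interpolation syntax. An editorial sketch with illustrative paths (``/tmp`` is only used here so the example resolves without error):

from omegaconf import OmegaConf as om

BaseConfig._register_resolvers(validate_paths=False)
yaml_src = (
    "load_path: ${path.choose:/nonexistent/ckpt,/tmp}\n"
    "data:\n"
    "  paths: ${path.glob:/tmp}\n"
)
cfg = om.create(yaml_src)
print(om.to_container(cfg, resolve=True))  # load_path -> "/tmp", data.paths -> ["/tmp"]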
101
+
102
+ @classmethod
103
+ def update_legacy_settings(cls, config: D) -> D:
104
+ """
105
+ Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
106
+ """
107
+ return config
108
+
109
+ @classmethod
110
+ def new(cls: Type[C], **kwargs) -> C:
111
+ cls._register_resolvers()
112
+ conf = om.structured(cls)
113
+ try:
114
+ if kwargs:
115
+ conf = om.merge(conf, kwargs)
116
+ return cast(C, om.to_object(conf))
117
+ except OmegaConfBaseException as e:
118
+ raise OLMoConfigurationError(str(e))
119
+
120
+ @classmethod
121
+ def load(
122
+ cls: Type[C],
123
+ path: PathOrStr,
124
+ overrides: Optional[List[str]] = None,
125
+ key: Optional[str] = None,
126
+ validate_paths: bool = True,
127
+ ) -> C:
128
+ """Load from a YAML file."""
129
+ cls._register_resolvers(validate_paths=validate_paths)
130
+ schema = om.structured(cls)
131
+ try:
132
+ raw = om.load(str(path))
133
+ if key is not None:
134
+ raw = raw[key] # type: ignore
135
+ raw = cls.update_legacy_settings(raw)
136
+ conf = om.merge(schema, raw)
137
+ if overrides:
138
+ conf = om.merge(conf, om.from_dotlist(overrides))
139
+ return cast(C, om.to_object(conf))
140
+ except OmegaConfBaseException as e:
141
+ raise OLMoConfigurationError(str(e))
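A usage sketch of ``load`` with dotlist overrides (the file name and override values are illustrative; ``TrainConfig`` is the concrete subclass defined later in this file):

cfg = TrainConfig.load(
    "config.yaml",
    overrides=["optimizer.learning_rate=2.0e-4", "data.num_workers=4"],
    validate_paths=False,
)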
142
+
143
+ def save(self, path: PathOrStr) -> None:
144
+ """Save to a YAML file."""
145
+ om.save(config=self, f=str(path))
146
+
147
+ def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
148
+ out = asdict(self) # type: ignore
149
+ if exclude is not None:
150
+ for name in exclude:
151
+ if name in out:
152
+ del out[name]
153
+ return out
154
+
155
+
156
+ class LayerNormType(StrEnum):
157
+ default = "default"
158
+ """
159
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
160
+ """
161
+
162
+ low_precision = "low_precision"
163
+ """
164
+ A low-precision version of the default LayerNorm.
165
+ """
166
+
167
+ rms = "rms"
168
+ """
169
+ An RMSNorm implementation. When using ``torch.compile`` this is
170
+ probably the fastest implementation.
171
+ """
172
+
173
+
174
+ class ActivationType(StrEnum):
175
+ gelu = "gelu"
176
+ relu = "relu"
177
+ swiglu = "swiglu"
178
+
179
+
180
+ class BlockType(StrEnum):
181
+ sequential = "sequential"
182
+
183
+ llama = "llama"
184
+ """
185
+ A block similar to the sequential block with slightly different
186
+ implementations of operations like attention to imitate the behavior of Llama.
187
+ """
188
+
189
+
190
+ class InitFnType(StrEnum):
191
+ mitchell = "mitchell"
192
+ """
193
+ The strategy suggested to us by Mitchell Wortsman from UW.
194
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
195
+ on the size of the weights as well as the depth of the layer.
196
+ """
197
+
198
+ normal = "normal"
199
+ """
200
+ All weights are initialized from the same normal distribution.
201
+ """
202
+
203
+ kaiming_normal = "kaiming_normal"
204
+ """
205
+ All weights are initialized with the Kaiming method from a normal distribution.
206
+ Note this currently won't work with FSDP.
207
+ """
208
+
209
+ fan_in = "fan_in"
210
+ """
211
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
212
+ is the input dimensionality of the kernel.
213
+ """
214
+
215
+ full_megatron = "full_megatron"
216
+ """
217
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
218
+ """
219
+
220
+
221
+ @dataclass
222
+ class ModelConfig(BaseConfig):
223
+ """
224
+ OLMo (model) configuration.
225
+ """
226
+
227
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
228
+
229
+ d_model: int = 768
230
+ """
231
+ The hidden size of the model.
232
+ """
233
+
234
+ n_heads: int = 12
235
+ """
236
+ The number of self-attention heads.
237
+ """
238
+
239
+ n_kv_heads: Optional[int] = None
240
+ """
241
+ The number of heads to use for keys and values. Defaults to `n_heads`.
242
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
243
+ Set this to 1 for multi-query attention.
244
+ Set it to some in-between value for Llama2-style grouped query attention.
245
+ """
246
+
247
+ clip_qkv: Optional[float] = None
248
+ """
249
+ Clip QKV to this value when set.
250
+ """
251
+
252
+ n_layers: int = 12
253
+ """
254
+ The number of layers/blocks.
255
+ """
256
+
257
+ mlp_ratio: int = 4
258
+ """
259
+ The ratio of the inner MLP dimensionality to ``d_model``.
260
+ This is only used when ``mlp_hidden_size`` is not set.
261
+ """
262
+
263
+ mlp_hidden_size: Optional[int] = None
264
+ """
265
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
266
+ """
267
+
268
+ activation_type: ActivationType = ActivationType.swiglu
269
+ """
270
+ The activation function to use within the MLP layers.
271
+ """
272
+
273
+ block_type: BlockType = BlockType.sequential
274
+ """
275
+ The transformer block implementation.
276
+ """
277
+
278
+ block_group_size: int = 1
279
+ """
280
+ The number of blocks to group together into a single parent block.
281
+ This has no effect on the number of parameters in the model and is only used to wrap groups
282
+ of blocks together with a single FSDP wrapper during training.
283
+ """
284
+
285
+ alibi: bool = False
286
+ """
287
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
288
+ """
289
+
290
+ alibi_bias_max: float = 8.0
291
+ """
292
+ Maximum absolute value of ALiBi bias.
293
+ """
294
+
295
+ rope: bool = False
296
+ """
297
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
298
+ """
299
+
300
+ rope_full_precision: bool = True
301
+ """
302
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
303
+ apply RoPE at the precision of the input.
304
+ """
305
+
306
+ flash_attention: bool = False
307
+ """
308
+ If ``True``, use ``FlashAttention``.
309
+ """
310
+
311
+ attention_dropout: float = 0.1
312
+ """
313
+ The dropout probability within the attention modules.
314
+ """
315
+
316
+ multi_query_attention: Optional[bool] = None
317
+ """
318
+ Deprecated. Use n_kv_heads instead.
319
+ """
320
+
321
+ attention_layer_norm: bool = False
322
+ """
323
+ Apply layer norm to the keys and queries within the attention mechanism.
324
+ This can help stabilize training.
325
+ """
326
+
327
+ residual_dropout: float = 0.1
328
+ """
329
+ The dropout probability for the MLP and attention output within each block.
330
+ """
331
+
332
+ embedding_dropout: float = 0.1
333
+ """
334
+ The dropout probability for embeddings.
335
+ """
336
+
337
+ layer_norm_type: LayerNormType = LayerNormType.default
338
+ """
339
+ The layernorm implementation to use.
340
+ """
341
+
342
+ layer_norm_with_affine: bool = True
343
+ """
344
+ Whether to include bias and weight parameters for the layer norms.
345
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
346
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
347
+ to ``False``.
348
+ """
349
+
350
+ attention_layer_norm_with_affine: bool = True
351
+ """
352
+ Toggle affine transform for the QK norms.
353
+ """
354
+
355
+ max_sequence_length: int = 1024
356
+ """
357
+ The maximum input sequence length supported by the model.
358
+ """
359
+
360
+ include_bias: bool = True
361
+ """
362
+ Whether or not to include bias parameters in linear layers.
363
+ In PaLM, they got rid of all bias terms because they found that large
364
+ models tend to have near 0 bias terms anyway.
365
+ """
366
+
367
+ bias_for_layer_norm: Optional[bool] = None
368
+ """
369
+ Whether or not to include bias parameters in layer norm.
370
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
371
+ layer norm.
372
+ When this is None (the default), it inherits the setting from include_bias.
373
+ """
374
+
375
+ scale_logits: bool = False
376
+ """
377
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
378
+ """
379
+
380
+ vocab_size: int = 50257
381
+ """
382
+ Vocabulary size of the model.
383
+ """
384
+
385
+ embedding_size: Optional[int] = 50304
386
+ """
387
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
388
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
389
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
390
+ substantially.
391
+ """
392
+
393
+ weight_tying: bool = True
394
+ """
395
+ Whether to tie output linear weights to the input embedding.
396
+ """
397
+
398
+ eos_token_id: int = 50256
399
+ """
400
+ The ID of the end-of-sentence special token.
401
+ """
402
+
403
+ pad_token_id: int = 50256
404
+ """
405
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
406
+ """
407
+
408
+ init_device: Optional[str] = None
409
+ """
410
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
411
+ """
412
+
413
+ init_fn: InitFnType = InitFnType.normal
414
+ """
415
+ The weight initialization strategy.
416
+ """
417
+
418
+ init_std: float = 0.02
419
+ """
420
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
421
+ as "normal".
422
+ """
423
+
424
+ init_cutoff_factor: Optional[float] = None
425
+ """
426
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
427
+ as "normal". Setting this to ``None`` means values are not cut off.
428
+ """
429
+
430
+ precision: Optional[str] = None
431
+ """
432
+ Precision used to train/evaluate with. You shouldn't set this directly.
433
+ See :data:`TrainConfig.precision` instead.
434
+ """
435
+
436
+ @property
437
+ def effective_n_kv_heads(self) -> int:
438
+ if self.n_kv_heads is None:
439
+ if self.multi_query_attention is True:
440
+ return 1
441
+ else:
442
+ return self.n_heads
443
+ else:
444
+ if self.multi_query_attention is None:
445
+ return self.n_kv_heads
446
+ if self.multi_query_attention:
447
+ n_kv_heads_should_be = 1
448
+ else:
449
+ n_kv_heads_should_be = self.n_heads
450
+ if self.n_kv_heads == n_kv_heads_should_be:
451
+ return n_kv_heads_should_be
452
+ else:
453
+ raise OLMoConfigurationError(
454
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
455
+ )
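To make the ``n_kv_heads`` / ``effective_n_kv_heads`` interaction concrete, an editorial example covering the three attention variants, plus the ``embedding_size`` padding rule described above (all numbers illustrative):

mha = ModelConfig(n_heads=8, n_kv_heads=None)  # multi-head attention: 8 KV heads
mqa = ModelConfig(n_heads=8, n_kv_heads=1)     # multi-query attention: one shared KV head
gqa = ModelConfig(n_heads=8, n_kv_heads=2)     # grouped-query attention: 4 query heads per KV head
assert (mha.effective_n_kv_heads, mqa.effective_n_kv_heads, gqa.effective_n_kv_heads) == (8, 1, 2)

# embedding_size rounds the vocabulary up to a multiple of 128 for better throughput,
# e.g. vocab_size=50257 -> embedding_size=50304 (the default pairing).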
456
+
457
+
458
+ class OptimizerType(StrEnum):
459
+ lionw = "lionw"
460
+ adamw = "adamw"
461
+
462
+
463
+ @dataclass
464
+ class OptimizerConfig(BaseConfig):
465
+ name: OptimizerType = OptimizerType.lionw
466
+ learning_rate: float = 1.0e-4
467
+ weight_decay: float = 0.01
468
+ betas: Tuple[float, float] = (0.9, 0.95)
469
+
470
+ no_decay_norm_and_bias: Optional[bool] = None
471
+ """
472
+ Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
473
+ """
474
+
475
+ decay_norm_and_bias: bool = False
476
+ decay_embeddings: bool = False
477
+ metrics_log_interval: Optional[int] = None
478
+ """
479
+ The interval with which to collect and log detailed parameter-specific metrics.
480
+ This only applies when logging to W&B, since these metrics won't be logged to the console.
481
+ If not set, defaults to the wandb `log_interval`.
482
+ """
483
+
484
+ def __post_init__(self):
485
+ self.betas = tuple(self.betas) # type: ignore[assignment]
486
+
487
+ @classmethod
488
+ def update_legacy_settings(cls, config: D) -> D:
489
+ new_config = config.copy()
490
+ if om.is_dict(new_config):
491
+ assert isinstance(new_config, DictConfig)
492
+
493
+ if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
494
+ new_config.name = "lionw"
495
+ if hasattr(new_config, "eps"):
496
+ del new_config.eps
497
+
498
+ return new_config
499
+
500
+
501
+ class SchedulerType(StrEnum):
502
+ cosine_with_warmup = "cosine_with_warmup"
503
+ linear_with_warmup = "linear_with_warmup"
504
+ inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
505
+ max_scheduler = "max_scheduler"
506
+ constant = "constant"
507
+
508
+
509
+ class SchedulerUnits(StrEnum):
510
+ steps = "steps"
511
+ tokens = "tokens"
512
+
513
+
514
+ @dataclass
515
+ class SchedulerConfig(BaseConfig):
516
+ name: SchedulerType = SchedulerType.cosine_with_warmup
517
+ units: SchedulerUnits = SchedulerUnits.steps
518
+ t_warmup: Union[int, float] = 100
519
+ t_max: Optional[Union[int, float]] = None
520
+ alpha_f: float = 0.1
521
+
522
+ grad_clip_warmup_steps: Optional[Union[int, float]] = None
523
+ """
524
+ The warmup period for which the max grad norm (or norm ratio) will be set to its
525
+ warmup value of `max_grad_norm * grad_clip_warmup_factor`.
526
+ """
527
+
528
+ grad_clip_warmup_factor: Optional[float] = None
529
+ """
530
+ The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
531
+ vs after the warmup period.
532
+ """
533
+
534
+ warmup_min_lr: Optional[float] = None
535
+ """
536
+ The starting LR during the warmup period. If not set this defaults to 10% of
537
+ the target LR.
538
+ """
539
+
540
+
541
+ class PaddingDirection(StrEnum):
542
+ right = "right"
543
+ left = "left"
544
+
545
+
546
+ @dataclass
547
+ class DataConfig(BaseConfig):
548
+ paths: Optional[List[str]] = None
549
+ datasets: Optional[Dict[str, List[str]]] = None
550
+ label_mask_paths: Optional[List[str]] = None
551
+ pad_direction: PaddingDirection = PaddingDirection.right
552
+ generate_attention_mask: bool = False
553
+ num_workers: int = 0
554
+ drop_last: bool = False
555
+ pin_memory: bool = False
556
+ prefetch_factor: Optional[int] = None
557
+ persistent_workers: bool = False
558
+ timeout: int = 0
559
+ seed: Optional[int] = None
560
+
561
+
562
+ class EvaluatorType(StrEnum):
563
+ downstream = "downstream"
564
+ lm = "lm"
565
+
566
+
567
+ @dataclass
568
+ class EvaluatorConfig(BaseConfig):
569
+ label: str
570
+ type: EvaluatorType = EvaluatorType.lm
571
+ data: DataConfig = field(default_factory=DataConfig)
572
+ device_eval_batch_size: Optional[int] = None
573
+ subset_num_batches: Optional[int] = None
574
+
575
+
576
+ class TruncationDirection(StrEnum):
577
+ right = "right"
578
+ left = "left"
579
+
580
+
581
+ @dataclass
582
+ class TokenizerConfig(BaseConfig):
583
+ identifier: str = "gpt2"
584
+ truncate_direction: TruncationDirection = TruncationDirection.right
585
+
586
+
587
+ @dataclass
588
+ class WandbConfig(BaseConfig):
589
+ project: Optional[str] = None
590
+ entity: Optional[str] = "zehui127-imperial-college-london"
591
+ group: Optional[str] = None
592
+ name: Optional[str] = None
593
+ tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
594
+ log_artifacts: bool = False
595
+ rank_zero_only: bool = True
596
+ log_interval: int = 50
597
+
598
+
599
+ @dataclass
600
+ class SpeedMonitorConfig(BaseConfig):
601
+ window_size: int = 100
602
+ gpu_flops_available: Optional[Union[float, int]] = None
603
+
604
+
605
+ @dataclass
606
+ class CompilerConfig(BaseConfig):
607
+ mode: Optional[str] = None
608
+ """
609
+ The mode to compile the model in. At the moment this can be "default",
610
+ "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
611
+ (the fastest for larger models, but takes a long time to compile).
612
+ """
613
+
614
+ fullgraph: bool = False
615
+ """
616
+ Whether it is OK to break model into several subgraphs when compiling.
617
+ Note that this is not compatible with FSDP.
618
+ """
619
+
620
+ backend: str = "inductor"
621
+ """
622
+ The backend to use.
623
+ """
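These fields mirror the like-named arguments of ``torch.compile``. A hedged sketch of how a trainer might apply them (the wiring is illustrative, assuming ``model`` is an ``nn.Module``):

compile_cfg = CompilerConfig(mode="default", fullgraph=False, backend="inductor")
model = torch.compile(model, mode=compile_cfg.mode, fullgraph=compile_cfg.fullgraph, backend=compile_cfg.backend)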
624
+
625
+
626
+ class FSDPWrapStrategy(StrEnum):
627
+ by_block = "by_block"
628
+ """
629
+ Wrap each OLMo block with its own FSDP instance.
630
+ """
631
+
632
+ by_block_and_size = "by_block_and_size"
633
+ """
634
+ Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
635
+ """
636
+
637
+ by_block_group = "by_block_group"
638
+ """
639
+ Wrap each block group together into its own FSDP instance.
640
+ This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
641
+ """
642
+
643
+ by_block_group_and_size = "by_block_group_and_size"
644
+ """
645
+ Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
646
+ """
647
+
648
+ size_based = "size_based"
649
+ """
650
+ Use PyTorch's default size-based auto wrap policy.
651
+ """
652
+
653
+ one_in_two = "one_in_two"
654
+ one_in_three = "one_in_three"
655
+ one_in_four = "one_in_four"
656
+ one_in_five = "one_in_five"
657
+
658
+
659
+ class FSDPPrecision(StrEnum):
660
+ pure = "pure"
661
+ """
662
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
663
+ and ``buffer_dtype`` all set to the autocast precision data type.
664
+ """
665
+
666
+ mixed = "mixed"
667
+ """
668
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
669
+ set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
670
+ """
671
+
672
+
673
+ @dataclass
674
+ class FSDPConfig(BaseConfig):
675
+ use_orig_params: bool = True
676
+ """
677
+ This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
678
+ """
679
+
680
+ sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
681
+
682
+ wrapping_strategy: Optional[FSDPWrapStrategy] = None
683
+ """
684
+ The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
685
+ FSDP instance.
686
+ """
687
+
688
+ precision: FSDPPrecision = FSDPPrecision.pure
689
+
690
+
691
+ class CheckpointType(StrEnum):
692
+ sharded = "sharded"
693
+ unsharded = "unsharded"
694
+ sharded_ephemeral = "sharded_ephemeral"
695
+
696
+
697
+ class ShardedCheckpointerType(StrEnum):
698
+ torch_new = "torch_new"
699
+ torch_legacy = "torch_legacy"
700
+ local = "local"
701
+ olmo_core = "olmo_core"
702
+
703
+
704
+ class ActivationCheckpointingStrategy(StrEnum):
705
+ whole_layer = "whole_layer"
706
+ """
707
+ Checkpoint every transformer layer.
708
+ """
709
+
710
+ one_in_two = "one_in_two"
711
+ """
712
+ Checkpoint one in two transformer layers.
713
+ """
714
+
715
+ one_in_three = "one_in_three"
716
+ """
717
+ Checkpoint one in three transformer layers.
718
+ """
719
+
720
+ one_in_four = "one_in_four"
721
+ """
722
+ Checkpoint one in four transformer layers.
723
+ """
724
+
725
+ two_in_three = "two_in_three"
726
+ """
727
+ Checkpoint two out of every three transformer layers.
728
+ """
729
+
730
+ three_in_four = "three_in_four"
731
+ """
732
+ Checkpoint three out of every four transformer layers.
733
+ """
734
+
735
+ fine_grained = "fine_grained"
736
+ """
737
+ Focus checkpointing where activations are cheap to recompute, which saves the most memory.
738
+ """
739
+
740
+
741
+ @dataclass
742
+ class TrainConfig(BaseConfig):
743
+ """
744
+ OLMo training configuration.
745
+ """
746
+
747
+ run_name: Optional[str] = None
748
+ """
749
+ The name of the run.
750
+ """
751
+
752
+ seed: int = 6198
753
+ """
754
+ Used to seed all initial RNG states.
755
+ """
756
+
757
+ epoch: Optional[int] = None
758
+ """
759
+ Increment this when starting a new epoch.
760
+ """
761
+
762
+ dry_run: bool = False
763
+ """
764
+ If ``True``, don't actually train.
765
+ """
766
+
767
+ model: ModelConfig = field(default_factory=ModelConfig)
768
+ """
769
+ OLMo Model configuration.
770
+ """
771
+
772
+ optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
773
+ """
774
+ Optimizer configuration.
775
+ """
776
+
777
+ scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
778
+ """
779
+ Learning rate scheduler configuration.
780
+ """
781
+
782
+ data: DataConfig = field(default_factory=DataConfig)
783
+ """
784
+ Training data configuration.
785
+ """
786
+
787
+ restore_dataloader: bool = True
788
+ """
789
+ When restarting, restore the data loader to where it left off.
790
+ If you are restarting in order to train on a different dataset, set this to ``False``.
791
+ """
792
+
793
+ fast_forward_batches: Optional[int] = None
794
+ """
795
+ When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
796
+ This can be useful when restarting due to a loss spike in order to skip the data that
797
+ corresponded to the spike.
798
+ """
799
+
800
+ evaluators: List[EvaluatorConfig] = field(default_factory=list)
801
+ """
802
+ Evaluation configurations.
803
+ """
804
+
805
+ eval_interval: int = 1000
806
+ """
807
+ How often (in terms of batches) to run evaluations.
808
+ """
809
+
810
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
811
+ """
812
+ Tokenizer configuration.
813
+ """
814
+
815
+ save_folder: str = "./"
816
+ """
817
+ The directory to save checkpoints to.
818
+ """
819
+
820
+ remote_save_folder: Optional[str] = None
821
+ """
822
+ A folder in a cloud bucket to upload saved checkpoints to.
823
+ """
824
+
825
+ canceled_check_interval: int = 50
826
+ """
827
+ How often (in batches) to check if the run has been canceled or reached its time limit.
828
+ """
829
+
830
+ save_interval: int = 1000
831
+ """
832
+ How often (in terms of steps) to save sharded training state checkpoints.
833
+ """
834
+
835
+ save_interval_unsharded: Optional[int] = None
836
+ """
837
+ How often (if at all) to save unsharded training state checkpoints.
838
+ For large models it can be costly to save these, so it usually makes sense to save
839
+ these less often than regular (sharded) training checkpoints.
840
+ """
841
+
842
+ save_interval_ephemeral: Optional[int] = None
843
+ """
844
+ How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
845
+ as those saved every `save_interval` except that at most only the most recent one of these is kept.
846
+ This is useful when you want to checkpoint often for restarts in case of failures, but don't
847
+ want to keep the majority of these checkpoints.
848
+
849
+ For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
850
+ a temporary checkpoint every 100 steps in case your job fails. In that case you would
851
+ set `save_interval=1000` and `save_interval_ephemeral=100`.
852
+ """
853
+
854
+ save_num_checkpoints_to_keep: int = -1
855
+ """
856
+ How many sharded checkpoints to keep.
857
+ """
858
+
859
+ save_num_unsharded_checkpoints_to_keep: int = -1
860
+ """
861
+ How many unsharded checkpoints to keep.
862
+ """
863
+
864
+ save_overwrite: bool = False
865
+ """
866
+ If ``True``, overwrite any conflicting checkpoint files.
867
+ """
868
+
869
+ force_save_unsharded: bool = False
870
+ """
871
+ Save an unsharded checkpoint before training (even during a dry run).
872
+ Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
873
+ checkpoint into an unsharded checkpoint.
874
+ """
875
+
876
+ no_pre_train_checkpoint: bool = False
877
+ """
878
+ Skip saving pre-train checkpoint.
879
+ """
880
+
881
+ load_path: Optional[str] = None
882
+ """
883
+ The path to a training checkpoint to restore/resume from.
884
+
885
+ Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
886
+ a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
887
+ For example,
888
+
889
+ ```bash
890
+ --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
891
+ ```
892
+ """
893
+
894
+ load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
895
+ """
896
+ The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
897
+ """
898
+
899
+ reset_optimizer_state: bool = False
900
+ """
901
+ When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
902
+ We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
903
+ curve (according to the current learning rate schedule settings), and continues from there.
904
+ """
905
+
906
+ reset_trainer_state: bool = False
907
+ """
908
+ When this is set we don't restore the trainer state from a checkpoint.
909
+ """
910
+
911
+ sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
912
+ """
913
+ The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
914
+ """
915
+
916
+ new_style_checkpoints: Optional[bool] = None
917
+ """
918
+ Deprecated. Use ``sharded_checkpointer`` instead.
919
+ """
920
+
921
+ max_duration: Union[int, str] = 10000
922
+ """
923
+ How long to train for.
924
+
925
+ If specified without a unit (the default), the units are assumed to be steps.
926
+ You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
927
+ 2 trillion tokens.
928
+ """
929
+
930
+ global_train_batch_size: int = 512
931
+ """
932
+ The effective global batch size.
933
+ """
934
+
935
+ device_train_batch_size: Optional[int] = None # calculated automatically
936
+ """
937
+ Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
938
+ """
939
+
940
+ device_train_microbatch_size: int = 16
941
+ """
942
+ The number of instances passed to the model in a single forward-backward pass. You should set
943
+ this as large as you can based on available GPU memory.
944
+ """
945
+
946
+ device_eval_batch_size: int = 16
947
+ """
948
+ The number of evaluation instances passed to the model in a single forward pass on each device.
949
+ """
950
+
951
+ eval_subset_num_batches: int = -1
952
+ """
953
+ The number of batches to use for downstream evaluation from each dataset.
954
+ """
955
+
956
+ eval_on_load: bool = False
957
+ """
958
+ When resuming from a checkpoint, run the evaluation loop right away.
959
+ """
960
+
961
+ device_train_grad_accum: Optional[int] = None # calculated automatically
962
+ """
963
+ Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
964
+ """
965
+
966
+ max_grad_norm: Optional[float] = None
967
+ """
968
+ Clip gradient norms to this value if set.
969
+ """
970
+
971
+ max_grad_norm_ratio: Optional[float] = None
972
+ """
973
+ If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
974
+ This takes priority over `max_grad_norm` when set.
975
+ """
976
+
977
+ precision: Optional[str] = None
978
+ """
979
+ Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
980
+ """
981
+
982
+ wandb: Optional[WandbConfig] = None
983
+ """
984
+ Weights & Biases configuration.
985
+ """
986
+
987
+ speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
988
+ """
989
+ Speed monitor configuration.
990
+ """
991
+
992
+ console_log_interval: int = 1
993
+ """
994
+ How often to log to the console.
995
+ """
996
+
997
+ gen1_gc_interval: Optional[int] = 1
998
+ """
999
+ How often (in steps) to run generation 1 garbage collection.
1000
+ Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it).
1001
+ """
1002
+
1003
+ compile: Optional[CompilerConfig] = None
1004
+ """
1005
+ Settings for compiling the model with ``torch.compile()``.
1006
+ """
1007
+
1008
+ fsdp: FSDPConfig = field(default_factory=FSDPConfig)
1009
+ """
1010
+ Fully sharded data parallel settings.
1011
+ """
1012
+
1013
+ softmax_auxiliary_loss: bool = False
1014
+ """
1015
+ If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
1016
+ normalizing term to be close to 0.
1017
+ """
1018
+
1019
+ time_limit: Optional[float] = 60 * 60 * 47.5
1020
+ """
1021
+ The maximum amount of time to train for before saving a checkpoint and ending early.
1022
+ On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
1023
+ to write out a final checkpoint.
1024
+ """
1025
+
1026
+ extra_steps_after_cancel: int = 10
1027
+ """
1028
+ Under certain conditions when a run is canceled we train for a few extra steps after saving
1029
+ the final checkpoint so that when the run is restarted from the latest checkpoint we have some
1030
+ overlap in metrics.
1031
+ """
1032
+
1033
+ early_stopping_factor: Optional[float] = None
1034
+
1035
+ save_data_indices: bool = True
1036
+ """
1037
+ Save training data indices from each batch for each worker.
1038
+ """
1039
+
1040
+ python_profiling: bool = False
1041
+ """
1042
+ Whether to run the Python profiler on batches 6, 7, and 8.
1043
+ """
1044
+
1045
+ torch_profiling: bool = False
1046
+ """
1047
+ Whether to run the PyTorch profiler on batches 6, 7, and 8.
1048
+ """
1049
+
1050
+ stop_at: Optional[int] = None
1051
+ """
1052
+ Stop at a specific step.
1053
+ """
1054
+
1055
+ stop_after: Optional[int] = None
1056
+ """
1057
+ Stop after a specific number of steps.
1058
+ """
1059
+
1060
+ activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
1061
+ """
1062
+ The activation checkpointing strategy to use.
1063
+ """
1064
+
1065
+ fused_loss: Optional[bool] = None
1066
+ """
1067
+ Whether to use the fused CE loss function from `flash-attn`.
1068
+ """
1069
+
1070
+ @property
1071
+ def autocast_precision(self) -> torch.dtype:
1072
+ if self.precision == "amp_bf16":
1073
+ return torch.bfloat16
1074
+ elif self.precision == "amp_fp16":
1075
+ return torch.float16
1076
+ elif self.precision == "fp32":
1077
+ return torch.float32
1078
+ else:
1079
+ raise ValueError(f"Unexpected precision type '{self.precision}'")
1080
+
1081
+ @property
1082
+ def fsdp_precision(self) -> MixedPrecision:
1083
+ if self.fsdp.precision == FSDPPrecision.pure:
1084
+ return MixedPrecision(
1085
+ param_dtype=self.autocast_precision,
1086
+ reduce_dtype=self.autocast_precision,
1087
+ buffer_dtype=self.autocast_precision,
1088
+ )
1089
+ elif self.fsdp.precision == FSDPPrecision.mixed:
1090
+ return MixedPrecision(
1091
+ param_dtype=self.autocast_precision,
1092
+ reduce_dtype=torch.float32,
1093
+ buffer_dtype=self.autocast_precision,
1094
+ )
1095
+ else:
1096
+ raise NotImplementedError(f"{self.fsdp.precision}")
1097
+
1098
+ @classmethod
1099
+ def update_legacy_settings(cls, config: D) -> D:
1100
+ new_config = config.copy()
1101
+ if om.is_dict(new_config):
1102
+ assert isinstance(new_config, DictConfig)
1103
+
1104
+ if hasattr(new_config, "activation_checkpointing"):
1105
+ if new_config.activation_checkpointing is False:
1106
+ new_config.activation_checkpointing = None
1107
+ if new_config.activation_checkpointing is True:
1108
+ new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
1109
+
1110
+ if hasattr(new_config, "optimizer"):
1111
+ new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
1112
+
1113
+ return new_config
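The batch-size fields of ``TrainConfig`` are tied together by simple integer arithmetic. A worked example with illustrative numbers (assume 4 ranks):

cfg = TrainConfig(global_train_batch_size=512, device_train_microbatch_size=16)
world_size = 4  # assumed number of ranks
device_train_batch_size = cfg.global_train_batch_size // world_size                    # 128 instances per rank per step
device_train_grad_accum = device_train_batch_size // cfg.device_train_microbatch_size  # 8 micro-batches accumulated per step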
model/configuration_olmo.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ OLMo configuration
3
+ """
4
+
5
+ from transformers import AutoConfig, PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ from .config import ModelConfig
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class OLMoConfig(PretrainedConfig):
14
+ model_type = "olmo-gfm"
15
+ keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
16
+
17
+ def __init__(self, use_cache: bool = False, num_labels: int = 2, **kwargs):
18
+ model_config = ModelConfig()
19
+ all_kwargs = model_config.asdict()
20
+ all_kwargs.update(kwargs)
21
+ all_kwargs.update({"use_cache": use_cache, "num_labels": num_labels})
22
+ all_kwargs.update(
23
+ {
24
+ "architectures": all_kwargs.get("architectures", ["OLMoModelForCausalLM"])
25
+ or ["OLMoModelForCausalLM"]
26
+ }
27
+ )
28
+ super().__init__(**all_kwargs)
29
+
30
+ @property
31
+ def num_attention_heads(self):
32
+ return self.n_heads
33
+
34
+ @property
35
+ def num_hidden_layers(self):
36
+ return self.n_layers
37
+
38
+ @property
39
+ def hidden_size(self):
40
+ return self.d_model
41
+
42
+
43
+ # Register the config class so that it is available for transformer pipelines, auto-loading etc.
44
+ AutoConfig.register("olmo-gfm", OLMoConfig)
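A minimal usage sketch for the registration above, assuming this package has been imported so the `AutoConfig.register` call has run; the keyword names mirror config.json:

from transformers import AutoConfig

cfg = AutoConfig.for_model("olmo-gfm", d_model=512, n_heads=8, n_layers=8)

# The properties defined above expose HF-style names for the OLMo fields.
assert cfg.hidden_size == cfg.d_model == 512
assert cfg.num_attention_heads == cfg.n_heads == 8
assert cfg.num_hidden_layers == cfg.n_layers == 8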
model/exceptions.py ADDED
@@ -0,0 +1,50 @@
1
+ __all__ = [
2
+ "OLMoError",
3
+ "OLMoConfigurationError",
4
+ "OLMoCliError",
5
+ "OLMoEnvironmentError",
6
+ "OLMoNetworkError",
7
+ "OLMoCheckpointError",
8
+ ]
9
+
10
+
11
+ class OLMoError(Exception):
12
+ """
13
+ Base class for all custom OLMo exceptions.
14
+ """
15
+
16
+
17
+ class OLMoConfigurationError(OLMoError):
18
+ """
19
+ An error with a configuration file.
20
+ """
21
+
22
+
23
+ class OLMoCliError(OLMoError):
24
+ """
25
+ An error from incorrect CLI usage.
26
+ """
27
+
28
+
29
+ class OLMoEnvironmentError(OLMoError):
30
+ """
31
+ An error from incorrect environment variables.
32
+ """
33
+
34
+
35
+ class OLMoNetworkError(OLMoError):
36
+ """
37
+ An error with a network request.
38
+ """
39
+
40
+
41
+ class OLMoCheckpointError(OLMoError):
42
+ """
43
+ An error reading from or writing to a checkpoint.
44
+ """
45
+
46
+
47
+ class OLMoThreadError(Exception):
48
+ """
49
+ Raised when a thread fails.
50
+ """
model/initialization.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .config import InitFnType, ModelConfig
8
+ from .util import StrEnum
9
+
10
+ __all__ = ["init_weights", "ModuleType"]
11
+
12
+
13
+ class ModuleType(StrEnum):
14
+ in_module = "in"
15
+ out_module = "out"
16
+ emb = "emb"
17
+ final_out = "final_out"
18
+
19
+
20
+ def init_weights(
21
+ config: ModelConfig,
22
+ module: Union[nn.Linear, nn.Embedding],
23
+ d: Optional[int] = None,
24
+ layer_id: Optional[int] = None,
25
+ std_factor: float = 1.0,
26
+ type_of_module: Optional[ModuleType] = None,
27
+ ) -> None:
28
+ """
29
+ Initialize weights of a linear or embedding module.
30
+
31
+ :param config: The model config.
32
+ :param module: The linear or embedding submodule to initialize.
33
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
34
+ for fused layers.
35
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
36
+ ``1 / sqrt(2 * (layer_id + 1))``.
37
+ """
38
+ d = d if d is not None else config.d_model
39
+ if config.init_fn == InitFnType.normal:
40
+ std = config.init_std * std_factor
41
+ if config.init_cutoff_factor is not None:
42
+ cutoff_value = config.init_cutoff_factor * std
43
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
44
+ else:
45
+ nn.init.normal_(module.weight, mean=0.0, std=std)
46
+ elif config.init_fn == InitFnType.mitchell:
47
+ std = std_factor / math.sqrt(d)
48
+ if layer_id is not None:
49
+ std = std / math.sqrt(2 * (layer_id + 1))
50
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
51
+ elif config.init_fn == InitFnType.kaiming_normal:
52
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
53
+ elif config.init_fn == InitFnType.fan_in:
54
+ std = std_factor / math.sqrt(d)
55
+ nn.init.normal_(module.weight, mean=0.0, std=std)
56
+ elif config.init_fn == InitFnType.full_megatron:
57
+ if type_of_module is None:
58
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
59
+
60
+ cutoff_factor = config.init_cutoff_factor
61
+ if cutoff_factor is None:
62
+ cutoff_factor = 3
63
+
64
+ if type_of_module == ModuleType.in_module:
65
+ # for att_proj (same as QKV), ff_proj
66
+ std = config.init_std
67
+ elif type_of_module == ModuleType.out_module:
68
+ # for attn_out, ff_out
69
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
70
+ elif type_of_module == ModuleType.emb:
71
+ # positional embeddings (wpe)
72
+ # token embeddings (wte)
73
+ std = config.init_std
74
+ elif type_of_module == ModuleType.final_out:
75
+ # final output (ff_out)
76
+ std = config.d_model**-0.5
77
+ else:
78
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
79
+ nn.init.trunc_normal_(
80
+ module.weight,
81
+ mean=0.0,
82
+ std=std,
83
+ a=-cutoff_factor * std,
84
+ b=cutoff_factor * std,
85
+ )
86
+ else:
87
+ raise NotImplementedError(config.init_fn)
88
+
89
+ if isinstance(module, nn.Linear):
90
+ if module.bias is not None:
91
+ nn.init.zeros_(module.bias)
92
+
93
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
94
+ with torch.no_grad():
95
+ module.weight.div_(math.sqrt(2 * config.n_layers))
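In the `mitchell` branch above, the standard deviation is `std_factor / sqrt(d)`, shrunk by `sqrt(2 * (layer_id + 1))` when a layer id is given, and truncated at three standard deviations. A minimal sketch of the per-layer values for the `d_model=512`, `n_layers=8` model in config.json:

import math

d_model, std_factor = 512, 1.0
for layer_id in range(8):
    std = std_factor / math.sqrt(d_model)   # base "mitchell" std
    std /= math.sqrt(2 * (layer_id + 1))    # per-layer shrinkage
    print(f"layer {layer_id}: std={std:.5f}, truncated at +/-{3 * std:.5f}")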
model/model.py ADDED
@@ -0,0 +1,1625 @@
1
+ """
2
+ Adapted from
3
+ [MosaiclML](https://github.com/mosaicml/examples.git) and
4
+ [minGPT](https://github.com/karpathy/minGPT.git)
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import math
11
+ import sys
12
+ from abc import abstractmethod
13
+ from collections import defaultdict
14
+ from functools import partial
15
+ from typing import (
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ NamedTuple,
21
+ Optional,
22
+ Sequence,
23
+ Set,
24
+ Tuple,
25
+ cast,
26
+ )
27
+
28
+ import torch
29
+ import torch.backends.cuda
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch import einsum
33
+
34
+ from .aliases import PathOrStr
35
+ from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
36
+ from .config import (
37
+ ActivationCheckpointingStrategy,
38
+ ActivationType,
39
+ BlockType,
40
+ CheckpointType,
41
+ FSDPWrapStrategy,
42
+ LayerNormType,
43
+ ModelConfig,
44
+ )
45
+ from .exceptions import OLMoConfigurationError
46
+ from .initialization import ModuleType, init_weights
47
+ from .torch_util import ensure_finite_
48
+
49
+ if sys.version_info.minor > 8:
50
+ from collections.abc import MutableMapping
51
+ elif sys.version_info.minor == 8:
52
+ from typing import MutableMapping
53
+ else:
54
+ raise SystemExit("This script requires Python 3.8 or higher")
55
+
56
+ __all__ = [
57
+ "LayerNormBase",
58
+ "LayerNorm",
59
+ "RMSLayerNorm",
60
+ "RotaryEmbedding",
61
+ "Activation",
62
+ "GELU",
63
+ "ReLU",
64
+ "SwiGLU",
65
+ "OLMoBlock",
66
+ "OLMoSequentialBlock",
67
+ "OLMo",
68
+ "OLMoOutput",
69
+ "OLMoGenerateOutput",
70
+ ]
71
+
72
+
73
+ log = logging.getLogger(__name__)
74
+
75
+
76
+ def activation_checkpoint_function(cfg: ModelConfig):
77
+ preserve_rng_state = (
78
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
79
+ )
80
+ from torch.utils.checkpoint import checkpoint
81
+
82
+ return partial(
83
+ checkpoint,
84
+ preserve_rng_state=preserve_rng_state,
85
+ use_reentrant=False,
86
+ )
87
+
88
+
89
+ def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
90
+ if strategy is None:
91
+ return False
92
+ elif (
93
+ (strategy == ActivationCheckpointingStrategy.whole_layer)
94
+ or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
95
+ or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
96
+ or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
97
+ or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
98
+ or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
99
+ ):
100
+ return True
101
+ else:
102
+ return False
103
+
104
+
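A minimal sketch of which block indices `should_checkpoint_block` selects for an 8-layer model under each partial strategy (the modular tests are copied from the branches above; `whole_layer` checkpoints every block):

n_layers = 8
strategies = {
    "one_in_two":    lambda i: i % 2 == 0,
    "one_in_three":  lambda i: i % 3 == 0,
    "one_in_four":   lambda i: i % 4 == 0,
    "two_in_three":  lambda i: i % 3 != 0,
    "three_in_four": lambda i: i % 4 != 0,
}
for name, keep in strategies.items():
    print(name, [i for i in range(n_layers) if keep(i)])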
105
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
106
+ """
107
+ Cache for attention biases and other things that would normally be stored as buffers.
108
+ We avoid using buffers because we've run into various issues doing so with FSDP.
109
+ In general it appears the way FSDP handles buffers is not well-defined.
110
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
111
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
112
+ NaNs when they're synchronized due to casting or some other issue.
113
+ """
114
+
115
+
116
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
117
+ if config.init_device is not None and config.init_device != "meta":
118
+ return torch.device(config.init_device)
119
+ else:
120
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
121
+
122
+
123
+ class Dropout(nn.Dropout):
124
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
125
+ if self.p == 0.0:
126
+ return input
127
+ else:
128
+ return F.dropout(input, self.p, self.training, self.inplace)
129
+
130
+
131
+ class LayerNormBase(nn.Module):
132
+ def __init__(
133
+ self,
134
+ config: ModelConfig,
135
+ *,
136
+ size: Optional[int] = None,
137
+ elementwise_affine: Optional[bool] = True,
138
+ eps: float = 1e-05,
139
+ ):
140
+ super().__init__()
141
+ self.config = config
142
+ self.eps = eps
143
+ self.normalized_shape = (size or config.d_model,)
144
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
145
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
146
+ use_bias = self.config.bias_for_layer_norm
147
+ if use_bias is None:
148
+ use_bias = self.config.include_bias
149
+ if use_bias:
150
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
151
+ else:
152
+ self.register_parameter("bias", None)
153
+ else:
154
+ self.register_parameter("bias", None)
155
+ self.register_parameter("weight", None)
156
+
157
+ @abstractmethod
158
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
159
+ raise NotImplementedError
160
+
161
+ @classmethod
162
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
163
+ if config.layer_norm_type == LayerNormType.default:
164
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
165
+ elif config.layer_norm_type == LayerNormType.low_precision:
166
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
167
+ elif config.layer_norm_type == LayerNormType.rms:
168
+ return RMSLayerNorm(config, size=size, **kwargs)
169
+ else:
170
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
171
+
172
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
173
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
174
+ # `is_autocast_cpu_enabled()` for CPU autocast.
175
+ # See https://github.com/pytorch/pytorch/issues/110966.
176
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
177
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
178
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
179
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
180
+ else:
181
+ return tensor
182
+
183
+ def reset_parameters(self):
184
+ if self.weight is not None:
185
+ torch.nn.init.ones_(self.weight) # type: ignore
186
+ if self.bias is not None:
187
+ torch.nn.init.zeros_(self.bias) # type: ignore
188
+
189
+
190
+ class LayerNorm(LayerNormBase):
191
+ """
192
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ config: ModelConfig,
198
+ size: Optional[int] = None,
199
+ low_precision: bool = False,
200
+ elementwise_affine: Optional[bool] = None,
201
+ eps: float = 1e-05,
202
+ ):
203
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
204
+ self.low_precision = low_precision
205
+
206
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
207
+ if self.low_precision:
208
+ module_device = x.device
209
+ downcast_x = self._cast_if_autocast_enabled(x)
210
+ downcast_weight = (
211
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
212
+ )
213
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
214
+ with torch.autocast(enabled=False, device_type=module_device.type):
215
+ return F.layer_norm(
216
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
217
+ )
218
+ else:
219
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
220
+
221
+
222
+ class RMSLayerNorm(LayerNormBase):
223
+ """
224
+ RMS layer norm, a simplified :class:`LayerNorm` implementation.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ config: ModelConfig,
230
+ size: Optional[int] = None,
231
+ elementwise_affine: Optional[bool] = None,
232
+ eps: float = 1e-5,
233
+ ):
234
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
235
+
236
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
237
+ with torch.autocast(enabled=False, device_type=x.device.type):
238
+ og_dtype = x.dtype
239
+ x = x.to(torch.float32)
240
+ variance = x.pow(2).mean(-1, keepdim=True)
241
+ x = x * torch.rsqrt(variance + self.eps)
242
+ x = x.to(og_dtype)
243
+
244
+ if self.weight is not None:
245
+ if self.bias is not None:
246
+ return self.weight * x + self.bias
247
+ else:
248
+ return self.weight * x
249
+ else:
250
+ return x
251
+
252
+
253
+ class RotaryEmbedding(nn.Module):
254
+ """
255
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
256
+ """
257
+
258
+ def __init__(self, config: ModelConfig, cache: BufferCache):
259
+ super().__init__()
260
+ self.config = config
261
+ self.__cache = cache
262
+ # Warm up cache.
263
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
264
+
265
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
266
+ if (
267
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
268
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
269
+ and pos_sin.shape[-2] >= seq_len
270
+ and pos_cos.shape[-2] >= seq_len
271
+ ):
272
+ if pos_sin.device != device:
273
+ pos_sin = pos_sin.to(device)
274
+ self.__cache["rope_pos_sin"] = pos_sin
275
+ if pos_cos.device != device:
276
+ pos_cos = pos_cos.to(device)
277
+ self.__cache["rope_pos_cos"] = pos_cos
278
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
279
+
280
+ with torch.autocast(device.type, enabled=False):
281
+ dim = self.config.d_model // self.config.n_heads
282
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
283
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
284
+ freqs = einsum("i , j -> i j", seq, inv_freq)
285
+ positions = torch.cat((freqs, freqs), dim=-1)
286
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
287
+ self.__cache["rope_pos_sin"] = pos_sin
288
+ self.__cache["rope_pos_cos"] = pos_cos
289
+ return pos_sin, pos_cos
290
+
291
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
292
+ B, nh, T, hs = x.size()
293
+ x = x.view(B, nh, T, 2, hs // 2)
294
+ x1, x2 = x.unbind(dim=-2)
295
+ return torch.cat((-x2, x1), dim=-1)
296
+
297
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
298
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
299
+
300
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
301
+ if self.config.rope_full_precision:
302
+ q_, k_ = q.float(), k.float()
303
+ else:
304
+ q_, k_ = q, k
305
+
306
+ with torch.autocast(q.device.type, enabled=False):
307
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
308
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
309
+ pos_sin = pos_sin.type_as(q_)
310
+ pos_cos = pos_cos.type_as(q_)
311
+ q_ = self.apply_rotary_pos_emb(
312
+ pos_sin[:, :, key_len - query_len : key_len, :],
313
+ pos_cos[:, :, key_len - query_len : key_len, :],
314
+ q_,
315
+ )
316
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
317
+ return q_.type_as(q), k_.type_as(k)
318
+
319
+
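`rotate_half` above splits the head dimension into two halves and maps `(x1, x2)` to `(-x2, x1)`; `apply_rotary_pos_emb` then forms `t * cos + rotate_half(t) * sin`, a rotation of each `(x1_i, x2_i)` pair that leaves the per-head norm unchanged. A minimal sketch checking that property on random data:

import torch

B, nh, T, hs = 1, 1, 4, 8                 # head size must be even
t = torch.randn(B, nh, T, hs)
x1, x2 = t.view(B, nh, T, 2, hs // 2).unbind(dim=-2)
rotated = torch.cat((-x2, x1), dim=-1)    # same split/swap as rotate_half
angle = torch.rand(())
out = t * torch.cos(angle) + rotated * torch.sin(angle)
print(torch.allclose(out.norm(dim=-1), t.norm(dim=-1), atol=1e-5))  # True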
320
+ class Activation(nn.Module):
321
+ def __init__(self, config: ModelConfig):
322
+ super().__init__()
323
+ self.config = config
324
+
325
+ @abstractmethod
326
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
327
+ raise NotImplementedError
328
+
329
+ @property
330
+ @abstractmethod
331
+ def output_multiplier(self) -> float:
332
+ raise NotImplementedError
333
+
334
+ @classmethod
335
+ def build(cls, config: ModelConfig) -> Activation:
336
+ if config.activation_type == ActivationType.gelu:
337
+ return cast(Activation, GELU(approximate="none"))
338
+ elif config.activation_type == ActivationType.relu:
339
+ return cast(Activation, ReLU(inplace=False))
340
+ elif config.activation_type == ActivationType.swiglu:
341
+ return SwiGLU(config)
342
+ else:
343
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
344
+
345
+
346
+ class GELU(nn.GELU):
347
+ @property
348
+ def output_multiplier(self) -> float:
349
+ return 1.0
350
+
351
+
352
+ class ReLU(nn.ReLU):
353
+ @property
354
+ def output_multiplier(self) -> float:
355
+ return 1.0
356
+
357
+
358
+ class SwiGLU(Activation):
359
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
360
+ x, gate = x.chunk(2, dim=-1)
361
+ return F.silu(gate) * x
362
+
363
+ @property
364
+ def output_multiplier(self) -> float:
365
+ return 0.5
366
+
367
+
368
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
369
+ att_bias = torch.triu(
370
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
371
+ diagonal=1,
372
+ )
373
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
374
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
375
+
376
+
377
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
378
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
379
+ if causal_bias.device != device:
380
+ causal_bias = causal_bias.to(device)
381
+ cache["causal_attention_bias"] = causal_bias
382
+ return causal_bias
383
+ with torch.autocast(device.type, enabled=False):
384
+ causal_bias = causal_attention_bias(seq_len, device)
385
+ cache["causal_attention_bias"] = causal_bias
386
+ return causal_bias
387
+
388
+
389
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
390
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
391
+
392
+ # shape: (1, 1, seq_len, seq_len)
393
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
394
+ alibi_bias.abs_().mul_(-1)
395
+
396
+ # shape: (n_heads,)
397
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
398
+ m.mul_(config.alibi_bias_max / config.n_heads)
399
+
400
+ # shape: (1, n_heads, seq_len, seq_len)
401
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
402
+
403
+
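`causal_attention_bias` above fills the strictly upper triangle with the dtype's most negative value so that, after softmax, each position attends only to itself and earlier positions. A minimal sketch for `seq_len=4`:

import torch

seq_len = 4
bias = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
bias.masked_fill_(bias == 1, torch.finfo(bias.dtype).min)
print(bias)  # zeros on and below the diagonal, about -3.4e38 above it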
404
+ class OLMoBlock(nn.Module):
405
+ """
406
+ A base class for transformer block implementations.
407
+ """
408
+
409
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
410
+ super().__init__()
411
+ self.layer_id = layer_id
412
+ self.config = config
413
+ self.hidden_size = (
414
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
415
+ )
416
+ self.__cache = cache
417
+ assert config.d_model % config.n_heads == 0
418
+
419
+ self._activation_checkpoint_fn = None
420
+
421
+ # Dropout.
422
+ self.dropout = Dropout(config.residual_dropout)
423
+
424
+ # Layer norms.
425
+ self.k_norm: Optional[LayerNormBase] = None
426
+ self.q_norm: Optional[LayerNormBase] = None
427
+ if config.attention_layer_norm:
428
+ assert config.effective_n_kv_heads is not None
429
+ self.k_norm = LayerNormBase.build(
430
+ config,
431
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
432
+ elementwise_affine=config.attention_layer_norm_with_affine,
433
+ )
434
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
435
+
436
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
437
+ if config.clip_qkv is not None:
438
+ assert config.clip_qkv > 0
439
+
440
+ # Activation function.
441
+ self.act = Activation.build(config)
442
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
443
+
444
+ # Attention output projection.
445
+ self.attn_out = nn.Linear(
446
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
447
+ )
448
+
449
+ # Feed-forward output projection.
450
+ self.ff_out = nn.Linear(
451
+ int(self.act.output_multiplier * self.hidden_size),
452
+ config.d_model,
453
+ bias=config.include_bias,
454
+ device=config.init_device,
455
+ )
456
+ self.ff_out._is_residual = True # type: ignore
457
+
458
+ # Rotary embeddings.
459
+ if self.config.rope:
460
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
461
+
462
+ self.flash_attn_func = None
463
+ if config.flash_attention:
464
+ try:
465
+ from flash_attn import flash_attn_func # type: ignore
466
+
467
+ self.flash_attn_func = flash_attn_func
468
+ except ModuleNotFoundError:
469
+ pass
470
+
471
+ def reset_parameters(self):
472
+ if self.k_norm is not None:
473
+ self.k_norm.reset_parameters()
474
+ if self.q_norm is not None:
475
+ self.q_norm.reset_parameters()
476
+ init_weights(
477
+ self.config,
478
+ self.attn_out,
479
+ d=self.config.d_model,
480
+ layer_id=self.layer_id,
481
+ type_of_module=ModuleType.out_module,
482
+ )
483
+ init_weights(
484
+ self.config,
485
+ self.ff_out,
486
+ d=self.ff_out.in_features,
487
+ layer_id=self.layer_id,
488
+ type_of_module=ModuleType.out_module,
489
+ )
490
+
491
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
492
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
493
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
494
+ else:
495
+ self._activation_checkpoint_fn = None
496
+
497
+ @classmethod
498
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
499
+ target_dtype = input_dtype
500
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
501
+ # `is_autocast_cpu_enabled()` for CPU autocast.
502
+ # See https://github.com/pytorch/pytorch/issues/110966.
503
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
504
+ target_dtype = torch.get_autocast_gpu_dtype()
505
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
506
+ target_dtype = torch.get_autocast_cpu_dtype()
507
+ if bias.dtype != target_dtype:
508
+ bias = bias.to(target_dtype)
509
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
510
+ return bias
511
+
512
+ def _scaled_dot_product_attention(
513
+ self,
514
+ q: torch.Tensor,
515
+ k: torch.Tensor,
516
+ v: torch.Tensor,
517
+ attn_mask: Optional[torch.Tensor] = None,
518
+ dropout_p: float = 0.0,
519
+ is_causal: bool = False,
520
+ ) -> torch.Tensor:
521
+ """
522
+ Computes scaled dot product attention on query, key and value tensors, using an optional
523
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
524
+ """
525
+ if self.flash_attn_func is not None and attn_mask is None:
526
+ r = self.flash_attn_func(
527
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
528
+ )
529
+ return r.transpose(1, 2)
530
+ else:
531
+ # torch's sdpa doesn't support GQA, so we're doing this
532
+ assert k.size(1) == v.size(1)
533
+ num_kv_heads = k.size(1)
534
+ num_q_heads = q.size(1)
535
+ if num_q_heads != num_kv_heads:
536
+ assert num_q_heads % num_kv_heads == 0
537
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
538
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
539
+
540
+ return F.scaled_dot_product_attention(
541
+ q,
542
+ k,
543
+ v,
544
+ attn_mask=attn_mask,
545
+ dropout_p=dropout_p,
546
+ is_causal=is_causal,
547
+ )
548
+
549
+ def attention(
550
+ self,
551
+ q: torch.Tensor,
552
+ k: torch.Tensor,
553
+ v: torch.Tensor,
554
+ attention_bias: Optional[torch.Tensor] = None,
555
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
556
+ use_cache: bool = False,
557
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
558
+ B, T, C = q.size() # batch size, sequence length, d_model
559
+ dtype = k.dtype
560
+
561
+ # Optionally apply layer norm to keys and queries.
562
+ if self.q_norm is not None and self.k_norm is not None:
563
+ q = self.q_norm(q).to(dtype=dtype)
564
+ k = self.k_norm(k).to(dtype=dtype)
565
+
566
+ # Move head forward to be next to the batch dim.
567
+ # shape: (B, nh, T, hs)
568
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
569
+ # shape: (B, n_kv_h, T, hs)
570
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
571
+ # shape: (B, n_kv_h, T, hs)
572
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
573
+
574
+ if layer_past is not None:
575
+ past_key, past_value = layer_past
576
+ k = torch.cat((past_key, k), dim=-2)
577
+ v = torch.cat((past_value, v), dim=-2)
578
+
579
+ present = (k, v) if use_cache else None
580
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
581
+
582
+ if self.config.rope:
583
+ # Apply rotary embeddings.
584
+ q, k = self.rotary_emb(q, k)
585
+
586
+ if attention_bias is not None:
587
+ # Resize and cast attention bias.
588
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
589
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
590
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
591
+ # cause the SDP attn function to produce NaNs.
592
+ attention_bias = self._cast_attn_bias(
593
+ attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
594
+ )
595
+
596
+ # Get the attention scores.
597
+ # shape: (B, nh, T, hs)
598
+ att = self._scaled_dot_product_attention(
599
+ q,
600
+ k,
601
+ v,
602
+ attn_mask=attention_bias,
603
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
604
+ is_causal=attention_bias is None,
605
+ )
606
+
607
+ # Re-assemble all head outputs side-by-side.
608
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
609
+
610
+ # Apply output projection.
611
+ return self.attn_out(att), present
612
+
613
+ @abstractmethod
614
+ def forward(
615
+ self,
616
+ x: torch.Tensor,
617
+ attention_bias: Optional[torch.FloatTensor] = None,
618
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
619
+ use_cache: bool = False,
620
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
621
+ raise NotImplementedError
622
+
623
+ @classmethod
624
+ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
625
+ if config.block_type == BlockType.sequential:
626
+ return OLMoSequentialBlock(layer_id, config, cache)
627
+ elif config.block_type == BlockType.llama:
628
+ return OLMoLlamaBlock(layer_id, config, cache)
629
+ else:
630
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
631
+
632
+
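The grouped-query workaround in `OLMoBlock._scaled_dot_product_attention` above repeats each key/value head until q, k, and v all carry the same number of heads before calling torch's SDPA. A minimal sketch with made-up shapes:

import torch

B, T, hs = 2, 5, 64
num_q_heads, num_kv_heads = 8, 2
k = torch.randn(B, num_kv_heads, T, hs)
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
print(k.shape)  # torch.Size([2, 8, 5, 64])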
633
+ class OLMoSequentialBlock(OLMoBlock):
634
+ """
635
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
636
+ (plus another skip connection).
637
+ """
638
+
639
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
640
+ super().__init__(layer_id, config, cache)
641
+ # Layer norms.
642
+ self.attn_norm = LayerNorm.build(config)
643
+ self.ff_norm = LayerNorm.build(config)
644
+ # Attention input projection. Projects x -> (q, k, v)
645
+
646
+ head_dim = config.d_model // config.n_heads
647
+ self.fused_dims = (
648
+ config.d_model,
649
+ config.effective_n_kv_heads * head_dim,
650
+ config.effective_n_kv_heads * head_dim,
651
+ )
652
+ self.att_proj = nn.Linear(
653
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
654
+ )
655
+ # Feed-forward input projection.
656
+ self.ff_proj = nn.Linear(
657
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
658
+ )
659
+
660
+ def reset_parameters(self):
661
+ super().reset_parameters()
662
+ self.attn_norm.reset_parameters()
663
+ self.ff_norm.reset_parameters()
664
+ # NOTE: the standard deviation for these weights does not depend on the layer.
665
+ init_weights(
666
+ self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
667
+ )
668
+ init_weights(
669
+ self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
670
+ )
671
+
672
+ def forward(
673
+ self,
674
+ x: torch.Tensor,
675
+ attention_bias: Optional[torch.Tensor] = None,
676
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
677
+ use_cache: bool = False,
678
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
679
+ # Get query, key, value projections.
680
+ # shape:
681
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
682
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
683
+ # k, v: (batch_size, seq_len, d_model // n_heads)
684
+ # - for group query attn q: (batch_size, seq_len, d_model)
685
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
686
+ if self._activation_checkpoint_fn is not None:
687
+ qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
688
+ else:
689
+ qkv = self.att_proj(self.attn_norm(x))
690
+
691
+ if self.config.clip_qkv is not None:
692
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
693
+
694
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
695
+
696
+ # Get attention scores.
697
+ if self._activation_checkpoint_fn is not None:
698
+ att, cache = self._activation_checkpoint_fn( # type: ignore
699
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
700
+ )
701
+ else:
702
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
703
+
704
+ # Add attention scores.
705
+ # shape: (B, T, C)
706
+ x = x + self.dropout(att)
707
+
708
+ # Add feed-forward projection.
709
+ # shape: (batch_size, seq_len, d_model)
710
+ og_x = x
711
+ if self._activation_checkpoint_fn is not None:
712
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
713
+ else:
714
+ x = self.ff_norm(x)
715
+ x = self.ff_proj(x)
716
+ if self._activation_checkpoint_fn is not None:
717
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
718
+ else:
719
+ x = self.act(x)
720
+ x = self.ff_out(x)
721
+ x = self.dropout(x)
722
+ x = og_x + x
723
+
724
+ return x, cache
725
+
726
+
727
+ class OLMoLlamaBlock(OLMoBlock):
728
+ """
729
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
730
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
731
+ but some operations have slightly different implementations to imitate the
732
+ behavior of Llama.
733
+ """
734
+
735
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
736
+ super().__init__(layer_id, config, cache)
737
+ # Layer norms.
738
+ self.attn_norm = LayerNorm.build(config)
739
+ self.ff_norm = LayerNorm.build(config)
740
+ self.__cache = cache
741
+
742
+ # Attention input projection. Projects x -> (q, k, v)
743
+ if config.multi_query_attention:
744
+ q_proj_out_dim = config.d_model
745
+ k_proj_out_dim = config.d_model // config.n_heads
746
+ v_proj_out_dim = config.d_model // config.n_heads
747
+ else:
748
+ q_proj_out_dim = config.d_model
749
+ k_proj_out_dim = config.d_model
750
+ v_proj_out_dim = config.d_model
751
+ self.q_proj = nn.Linear(
752
+ config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
753
+ )
754
+ self.k_proj = nn.Linear(
755
+ config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
756
+ )
757
+ self.v_proj = nn.Linear(
758
+ config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
759
+ )
760
+
761
+ # Feed-forward input projection.
762
+ self.ff_proj = nn.Linear(
763
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
764
+ )
765
+
766
+ def reset_parameters(self):
767
+ super().reset_parameters()
768
+ self.attn_norm.reset_parameters()
769
+ self.ff_norm.reset_parameters()
770
+ # NOTE: the standard deviation for these weights does not depend on the layer.
771
+ init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
772
+ init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
773
+ init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
774
+ init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
775
+
776
+ def _scaled_dot_product_attention(
777
+ self,
778
+ q: torch.Tensor,
779
+ k: torch.Tensor,
780
+ v: torch.Tensor,
781
+ attn_mask: Optional[torch.Tensor] = None,
782
+ dropout_p: float = 0.0,
783
+ is_causal: bool = False,
784
+ ) -> torch.Tensor:
785
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
786
+
787
+ if is_causal:
788
+ assert attn_mask is None
789
+
790
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
791
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
792
+ elif attn_mask is not None:
793
+ attn_bias = attn_mask.to(q.dtype)
794
+ else:
795
+ attn_bias = torch.zeros_like(attn_weights)
796
+
797
+ attn_weights += attn_bias
798
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
799
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
800
+ return torch.matmul(attn_weights, v)
801
+
802
+ def forward(
803
+ self,
804
+ x: torch.Tensor,
805
+ attention_bias: Optional[torch.Tensor] = None,
806
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
807
+ use_cache: bool = False,
808
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
809
+ # Get query, key, value projections.
810
+ # shape:
811
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
812
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
813
+ # k, v: (batch_size, seq_len, d_model // n_heads)
814
+ x_normed = self.attn_norm(x)
815
+ q = self.q_proj(x_normed)
816
+ k = self.k_proj(x_normed)
817
+ v = self.v_proj(x_normed)
818
+
819
+ if self.config.clip_qkv is not None:
820
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
821
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
822
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
823
+
824
+ # Get attention scores.
825
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
826
+
827
+ # Add attention scores.
828
+ # shape: (B, T, C)
829
+ x = x + self.dropout(att)
830
+
831
+ # Add feed-forward projection.
832
+ # shape: (batch_size, seq_len, d_model)
833
+ og_x = x
834
+ if self._activation_checkpoint_fn is not None:
835
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
836
+ else:
837
+ x = self.ff_norm(x)
838
+ x = self.ff_proj(x)
839
+ if self._activation_checkpoint_fn is not None:
840
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
841
+ else:
842
+ x = self.act(x)
843
+ x = self.ff_out(x)
844
+ x = self.dropout(x)
845
+ x = og_x + x
846
+
847
+ return x, cache
848
+
849
+
850
+ class OLMoOutput(NamedTuple):
851
+ logits: torch.FloatTensor
852
+ """
853
+ A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
854
+ for the next token *before* normalization via (log) softmax.
855
+ """
856
+
857
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
858
+ """
859
+ Attention keys and values from each block.
860
+ """
861
+
862
+ hidden_states: Optional[Tuple[torch.Tensor]]
863
+ """
864
+ Hidden states from each block.
865
+ """
866
+
867
+
868
+ class OLMoGenerateOutput(NamedTuple):
869
+ token_ids: torch.LongTensor
870
+ """
871
+ The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
872
+ These do *not* include the original input IDs.
873
+ """
874
+
875
+ scores: torch.FloatTensor
876
+ """
877
+ The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
878
+ """
879
+
880
+
881
+ class OLMoBlockGroup(nn.ModuleList):
882
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
883
+ super().__init__(modules)
884
+ self.config = config
885
+ self.layer_offset = layer_offset
886
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
887
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
888
+
889
+ def forward(
890
+ self,
891
+ x: torch.Tensor,
892
+ attention_bias: Optional[torch.FloatTensor] = None,
893
+ layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
894
+ use_cache: bool = False,
895
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
896
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
897
+ for block_idx, block in enumerate(self):
898
+ layer_past = None if layers_past is None else layers_past[block_idx]
899
+ block_idx += self.layer_offset
900
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
901
+ # shape: (batch_size, seq_len, d_model)
902
+ x, cache = self._activation_checkpoint_fn( # type: ignore
903
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
904
+ )
905
+ else:
906
+ # shape: (batch_size, seq_len, d_model)
907
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
908
+ if attn_key_values is not None:
909
+ assert cache is not None
910
+ attn_key_values.append(cache)
911
+ return x, attn_key_values
912
+
913
+ def reset_parameters(self):
914
+ for block in self:
915
+ block.reset_parameters()
916
+
917
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
918
+ self.activation_checkpointing_strategy = strategy
919
+ for block in self:
920
+ block.set_activation_checkpointing(strategy)
921
+
922
+
923
+ class OLMo(nn.Module):
924
+ def __init__(self, config: ModelConfig, init_params: bool = True):
925
+ super().__init__()
926
+ self.config = config
927
+ self.__cache = BufferCache()
928
+
929
+ # Validate config.
930
+ if self.config.alibi and self.config.flash_attention:
931
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
932
+
933
+ if self.config.alibi and self.config.rope:
934
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
935
+
936
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
937
+ if self.config.embedding_size < self.config.vocab_size:
938
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
939
+ elif self.config.embedding_size % 128 != 0:
940
+ import warnings
941
+
942
+ warnings.warn(
943
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
944
+ )
945
+
946
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
947
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
948
+
949
+ if not (
950
+ 0 < self.config.block_group_size <= self.config.n_layers
951
+ and self.config.n_layers % self.config.block_group_size == 0
952
+ ):
953
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
954
+
955
+ torch.backends.cuda.enable_flash_sdp(True)
956
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
957
+
958
+ self.transformer = nn.ModuleDict(
959
+ dict(
960
+ wte=nn.Embedding(
961
+ config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
962
+ ),
963
+ emb_drop=Dropout(config.embedding_dropout),
964
+ ln_f=LayerNorm.build(config),
965
+ )
966
+ )
967
+
968
+ blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
969
+ if self.config.block_group_size > 1:
970
+ block_groups = [
971
+ OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
972
+ for i in range(0, config.n_layers, config.block_group_size)
973
+ ]
974
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
975
+ else:
976
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
977
+
978
+ if not (self.config.alibi or self.config.rope):
979
+ self.transformer.update(
980
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
981
+ )
982
+ if not config.weight_tying:
983
+ self.transformer.update(
984
+ {
985
+ "ff_out": nn.Linear(
986
+ config.d_model,
987
+ config.embedding_size or config.vocab_size,
988
+ bias=config.include_bias,
989
+ device=config.init_device,
990
+ )
991
+ }
992
+ )
993
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
994
+ if init_params and self.config.init_device != "meta":
995
+ self.reset_parameters()
996
+ self.__num_fwd_flops: Optional[int] = None
997
+
998
+ # Warm up cache.
999
+ if self.config.alibi:
1000
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1001
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1002
+
1003
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1004
+ self.activation_checkpointing_strategy = strategy
1005
+ if self.config.block_group_size != 1:
1006
+ for block_group in self.transformer.block_groups:
1007
+ block_group.set_activation_checkpointing(strategy)
1008
+ else:
1009
+ for block in self.transformer.blocks:
1010
+ block.set_activation_checkpointing(strategy)
1011
+
1012
+ @property
1013
+ def device(self) -> torch.device:
1014
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1015
+ if device.type == "meta":
1016
+ return _non_meta_init_device(self.config)
1017
+ else:
1018
+ return device
1019
+
1020
+ def reset_parameters(self):
1021
+ log.info("Initializing model parameters...")
1022
+ # Top-level embeddings / linear layers.
1023
+ init_weights(
1024
+ self.config,
1025
+ self.transformer.wte, # type: ignore
1026
+ std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
1027
+ type_of_module=ModuleType.emb,
1028
+ )
1029
+ if hasattr(self.transformer, "wpe"):
1030
+ init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
1031
+
1032
+ # Top-level layer norm.
1033
+ self.transformer.ln_f.reset_parameters() # type: ignore
1034
+
1035
+ # Output weights.
1036
+ if hasattr(self.transformer, "ff_out"):
1037
+ init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
1038
+
1039
+ # Let the blocks handle themselves.
1040
+ if self.config.block_group_size == 1:
1041
+ for block in self.transformer.blocks:
1042
+ block.reset_parameters()
1043
+ else:
1044
+ for block_group in self.transformer.block_groups:
1045
+ block_group.reset_parameters()
1046
+
1047
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1048
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1049
+ -1
1050
+ ] >= seq_len:
1051
+ if alibi_bias.device != device:
1052
+ alibi_bias = alibi_bias.to(device)
1053
+ self.__cache["alibi_attention_bias"] = alibi_bias
1054
+ return alibi_bias
1055
+ with torch.autocast(device.type, enabled=False):
1056
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1057
+ self.__cache["alibi_attention_bias"] = alibi_bias
1058
+ return alibi_bias
1059
+
1060
+ def forward(
1061
+ self,
1062
+ input_ids: torch.LongTensor,
1063
+ input_embeddings: Optional[torch.FloatTensor] = None,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ attention_bias: Optional[torch.Tensor] = None,
1066
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1067
+ use_cache: bool = False,
1068
+ last_logits_only: bool = False,
1069
+ output_hidden_states: Optional[bool] = None,
1070
+ ) -> OLMoOutput:
1071
+ """
1072
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1073
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1074
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1075
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1076
+ which input IDs are masked. A `1` value in the mask means that
1077
+ the corresponding input ID should *not* be ignored. A `0` means
1078
+ that the corresponding input ID is masked.
1079
+
1080
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1081
+ library.
1082
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1083
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1084
+ to introduce causal or other biases.
1085
+
1086
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1087
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1088
+ element in the sequence.
1089
+
1090
+ If the tensor is a float tensor, it will just be added to the attention
1091
+ scores before the softmax.
1092
+
1093
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1094
+ :param past_key_values: Pre-computed keys and values for each attention block.
1095
+ Can be used to speed up sequential decoding. The `input_ids` which have
1096
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1097
+ :param use_cache: If `True`, return key and value tensors for each block.
1098
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1099
+ This can speed up decoding when you only care about the next token.
1100
+ """
1101
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1102
+
1103
+ if past_key_values:
1104
+ assert len(past_key_values) == self.config.n_layers
1105
+
1106
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1107
+ if past_key_values is None:
1108
+ past_length = 0
1109
+ else:
1110
+ past_length = past_key_values[0][0].size(-2)
1111
+
1112
+ # Get embeddings of input.
1113
+ # shape: (batch_size, seq_len, d_model)
1114
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
1115
+
1116
+ if not (self.config.alibi or self.config.rope):
1117
+ # Get positional embeddings.
1118
+ # shape: (1, seq_len)
1119
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1120
+ # shape: (1, seq_len, d_model)
1121
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1122
+ x = pos_emb + x
1123
+
1124
+ # Add input + positional embeddings and apply dropout.
1125
+ # shape: (batch_size, seq_len, d_model)
1126
+ x = self.transformer.emb_drop(x) # type: ignore
1127
+
1128
+ # Transform the attention mask into what the blocks expect.
1129
+ if attention_mask is not None:
1130
+ # shape: (batch_size, 1, 1, seq_len)
1131
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1132
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1133
+
1134
+ # Merge attention mask with attention bias.
1135
+ if (
1136
+ attention_bias is not None
1137
+ or attention_mask is not None
1138
+ or self.config.alibi
1139
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1140
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1141
+ # scores correctly.
1142
+ or past_key_values is not None
1143
+ ):
1144
+ if attention_bias is None and self.config.alibi:
1145
+ attention_bias = get_causal_attention_bias(
1146
+ self.__cache, past_length + seq_len, x.device
1147
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1148
+ elif attention_bias is None:
1149
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1150
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1151
+ attention_bias = attention_bias.to(dtype=torch.float)
1152
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1153
+
1154
+ # Transform to the right shape and data type.
1155
+ mask_len = seq_len
1156
+ if attention_mask is not None:
1157
+ mask_len = attention_mask.shape[-1]
1158
+ elif past_key_values is not None:
1159
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1160
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1161
+
1162
+ # Add in the masking bias.
1163
+ if attention_mask is not None:
1164
+ attention_bias = attention_bias + attention_mask
1165
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1166
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1167
+ # it can produce NaNs.
1168
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1169
+
1170
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1171
+
1172
+ # decoder layers
1173
+ all_hidden_states = []
1174
+
1175
+ # Apply blocks one-by-one.
1176
+ if self.config.block_group_size == 1:
1177
+ for block_idx, block in enumerate(self.transformer.blocks):
1178
+ if output_hidden_states:
1179
+ # add hidden states
1180
+ all_hidden_states.append(x)
1181
+
1182
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1183
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
1184
+ # shape: (batch_size, seq_len, d_model)
1185
+ x, cache = self._activation_checkpoint_fn(
1186
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1187
+ )
1188
+ else:
1189
+ # shape: (batch_size, seq_len, d_model)
1190
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1191
+
1192
+ if attn_key_values is not None:
1193
+ assert cache is not None
1194
+ attn_key_values.append(cache)
1195
+ else:
1196
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1197
+ if output_hidden_states:
1198
+ # add hidden states
1199
+ all_hidden_states.append(x)
1200
+
1201
+ layers_past = (
1202
+ None
1203
+ if past_key_values is None
1204
+ else past_key_values[
1205
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1206
+ ]
1207
+ )
1208
+ x, cache = block_group(
1209
+ x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
1210
+ )
1211
+ if attn_key_values is not None:
1212
+ assert cache is not None
1213
+ attn_key_values.extend(cache)
1214
+
1215
+ if last_logits_only:
1216
+ # shape: (batch_size, 1, d_model)
1217
+ x = x[:, -1, :].unsqueeze(1)
1218
+
1219
+ # Apply final layer norm.
1220
+ # shape: (batch_size, seq_len or 1, d_model)
1221
+ x = self.transformer.ln_f(x) # type: ignore
1222
+ if output_hidden_states:
1223
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1224
+ all_hidden_states.append(x)
1225
+
1226
+ # Get logits.
1227
+ # shape: (batch_size, seq_len or 1, vocab_size)
1228
+ if self.config.weight_tying:
1229
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1230
+ else:
1231
+ logits = self.transformer.ff_out(x) # type: ignore
1232
+ if self.config.scale_logits:
1233
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1234
+
1235
+ return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
1236
+
1237
+ def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
1238
+ if wrap_strategy is None:
1239
+ return None
1240
+
1241
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
1242
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
1243
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1244
+ # but not other linear layers within a block.
1245
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1246
+ # return True in 'recurse' mode for simplicity.
1247
+ size_based_module_to_wrap = {self.transformer.wte}
1248
+ if hasattr(self.transformer, "ff_out"):
1249
+ size_based_module_to_wrap.add(self.transformer.ff_out)
1250
+
1251
+ if wrap_strategy == FSDPWrapStrategy.by_block:
1252
+
1253
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1254
+ del nonwrapped_numel
1255
+ wrap = isinstance(module, OLMoBlock)
1256
+ if recurse:
1257
+ return True
1258
+ else:
1259
+ return wrap
1260
+
1261
+ return fsdp_wrap_fn
1262
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1263
+
1264
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1265
+ del nonwrapped_numel
1266
+ wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
1267
+ if recurse:
1268
+ return True
1269
+ else:
1270
+ return wrap
1271
+
1272
+ return fsdp_wrap_fn
1273
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1274
+ if self.config.block_group_size <= 1:
1275
+ raise OLMoConfigurationError(
1276
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1277
+ )
1278
+
1279
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1280
+ del nonwrapped_numel
1281
+ wrap = isinstance(module, OLMoBlockGroup)
1282
+ if recurse:
1283
+ return True
1284
+ else:
1285
+ return wrap
1286
+
1287
+ return fsdp_wrap_fn
1288
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1289
+ if self.config.block_group_size <= 1:
1290
+ raise OLMoConfigurationError(
1291
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1292
+ )
1293
+
1294
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1295
+ del nonwrapped_numel
1296
+ wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
1297
+ if recurse:
1298
+ return True
1299
+ else:
1300
+ return wrap
1301
+
1302
+ return fsdp_wrap_fn
1303
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
1304
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1305
+
1306
+ return size_based_auto_wrap_policy
1307
+ elif wrap_strategy in {
1308
+ FSDPWrapStrategy.one_in_two,
1309
+ FSDPWrapStrategy.one_in_three,
1310
+ FSDPWrapStrategy.one_in_four,
1311
+ FSDPWrapStrategy.one_in_five,
1312
+ }:
1313
+ c = {
1314
+ FSDPWrapStrategy.one_in_two: 2,
1315
+ FSDPWrapStrategy.one_in_three: 3,
1316
+ FSDPWrapStrategy.one_in_four: 4,
1317
+ FSDPWrapStrategy.one_in_five: 5,
1318
+ }[wrap_strategy]
1319
+
1320
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1321
+ del nonwrapped_numel
1322
+ wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
1323
+ if recurse:
1324
+ return True
1325
+ else:
1326
+ return wrap
1327
+
1328
+ return fsdp_wrap_fn
1329
+ else:
1330
+ raise NotImplementedError(wrap_strategy)
1331
+
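As a quick illustration of how the policy returned above is typically consumed, here is a hedged sketch (not part of the checked-in code): it assumes a distributed process group is already initialized, that `model` is an `OLMo` instance, and that `FSDPWrapStrategy` comes from the config module.

# Hedged usage sketch for the wrap policy returned by get_fsdp_wrap_policy().
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

wrap_policy = model.get_fsdp_wrap_policy(FSDPWrapStrategy.by_block)   # wraps every OLMoBlock
fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy)                # requires torch.distributed to be initialized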
1332
+ def num_params(self, include_embedding: bool = True) -> int:
1333
+ """
1334
+ Get the total number of parameters.
1335
+ """
1336
+ params = (np for np in self.named_parameters())
1337
+ if not include_embedding:
1338
+ params = filter( # type: ignore
1339
+ lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
1340
+ params,
1341
+ )
1342
+ return sum(p.numel() for _, p in params)
1343
+
1344
+ @property
1345
+ def num_fwd_flops(self):
1346
+ if self.__num_fwd_flops:
1347
+ return self.__num_fwd_flops
1348
+ n_params = self.num_params()
1349
+ # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
1350
+ # each MAC is 2 FLOPs, so we multiply by 2, i.e. 2 * n_params
1351
+ # this gets us FLOPs / token
1352
+ params_flops_per_token = 2 * n_params
1353
+ params_flops_per_seq = params_flops_per_token * self.config.max_sequence_length
1354
+ # there are 2 FLOPs per MAC; attention has two matmuls, A = Q @ K^T and out = A @ V (hence the extra factor of 2)
1355
+ attn_flops_per_seq = (
1356
+ self.config.n_layers * 2 * 2 * (self.config.d_model * (self.config.max_sequence_length**2))
1357
+ )
1358
+ self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq
1359
+ return self.__num_fwd_flops
1360
+
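To make the estimate concrete, here is a hedged worked example using the values from this repo's config.yaml (d_model=512, n_layers=8, max_sequence_length=250); the parameter count is a placeholder and should come from `num_params()` in practice.

# Standalone sketch of the same arithmetic as num_fwd_flops.
d_model, n_layers, seq_len = 512, 8, 250
n_params = 100_000_000                                    # placeholder; use model.num_params()
params_flops_per_seq = 2 * n_params * seq_len             # 2 FLOPs per MAC, per token, per parameter
attn_flops_per_seq = n_layers * 2 * 2 * (d_model * seq_len**2)
print(params_flops_per_seq + attn_flops_per_seq)          # rough forward FLOPs per sequence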
1361
+ def generate(
1362
+ self,
1363
+ input_ids: torch.LongTensor,
1364
+ attention_mask: Optional[torch.Tensor] = None,
1365
+ attention_bias: Optional[torch.Tensor] = None,
1366
+ max_steps: int = 10,
1367
+ beam_size: int = 1,
1368
+ per_node_beam_size: Optional[int] = None,
1369
+ sampler: Optional[Sampler] = None,
1370
+ min_steps: Optional[int] = None,
1371
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
1372
+ constraints: Optional[List[Constraint]] = None,
1373
+ ) -> OLMoGenerateOutput:
1374
+ """
1375
+ Generate token IDs using beam search.
1376
+
1377
+ Note that by default ``beam_size`` is set to 1, which is greedy decoding.
1378
+
1379
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1380
+ :param attention_mask: An optional tensor of shape `(batch_size, seq_len)`, the same
1381
+ as for the forward method.
1382
+ :param attention_bias: A tensor of shape
1383
+ `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
1384
+ the same as for the forward method except that only this one shape is accepted here.
1385
+
1386
+ For an explanation of the other arguments, see :class:`BeamSearch`.
1387
+ """
1388
+ beam_search = BeamSearch(
1389
+ self.config.eos_token_id,
1390
+ max_steps=max_steps,
1391
+ beam_size=beam_size,
1392
+ per_node_beam_size=per_node_beam_size,
1393
+ sampler=sampler,
1394
+ min_steps=min_steps,
1395
+ final_sequence_scorer=final_sequence_scorer,
1396
+ constraints=constraints,
1397
+ )
1398
+
1399
+ # Validate inputs.
1400
+ batch_size, seq_len = input_ids.shape
1401
+ if attention_mask is not None:
1402
+ assert attention_mask.shape == (batch_size, seq_len)
1403
+ if attention_bias is not None:
1404
+ assert len(attention_bias.shape) == 4
1405
+ assert attention_bias.shape[:2] == (batch_size, 1)
1406
+ assert (
1407
+ seq_len + beam_search.max_steps
1408
+ <= attention_bias.shape[2]
1409
+ == attention_bias.shape[3]
1410
+ <= self.config.max_sequence_length
1411
+ )
1412
+
1413
+ tokens_generated = 0
1414
+
1415
+ def flatten_past_key_values(
1416
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
1417
+ ) -> Dict[str, torch.Tensor]:
1418
+ out = {}
1419
+ for i, (key, value) in enumerate(past_key_values):
1420
+ out[f"past_key_{i}"] = key
1421
+ out[f"past_value_{i}"] = value
1422
+ return out
1423
+
1424
+ def unflatten_past_key_values(
1425
+ past_key_values: Dict[str, torch.Tensor],
1426
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
1427
+ out = []
1428
+ for i in range(self.config.n_layers):
1429
+ past_key = past_key_values[f"past_key_{i}"]
1430
+ past_value = past_key_values[f"past_value_{i}"]
1431
+ out.append((past_key, past_value))
1432
+ return out
1433
+
1434
+ def step(
1435
+ last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
1436
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1437
+ nonlocal tokens_generated
1438
+
1439
+ attention_mask = state.get("attention_mask")
1440
+ attention_bias = state.get("attention_bias")
1441
+
1442
+ if tokens_generated > 0:
1443
+ past_key_values = unflatten_past_key_values(state)
1444
+ input_ids = last_predictions.unsqueeze(1)
1445
+ if attention_mask is not None:
1446
+ group_size = input_ids.shape[0]
1447
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
1448
+ else:
1449
+ past_key_values = None
1450
+ input_ids = state["input_ids"]
1451
+
1452
+ tokens_generated += 1
1453
+
1454
+ # Run forward pass of model to get logits, then normalize to get log probs.
1455
+ output = self(
1456
+ input_ids,
1457
+ attention_mask=attention_mask,
1458
+ attention_bias=attention_bias,
1459
+ past_key_values=past_key_values,
1460
+ use_cache=True,
1461
+ last_logits_only=True,
1462
+ )
1463
+ log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
1464
+
1465
+ # Create new state.
1466
+ state = flatten_past_key_values(output.attn_key_values)
1467
+ if attention_mask is not None:
1468
+ state["attention_mask"] = attention_mask
1469
+ if attention_bias is not None:
1470
+ state["attention_bias"] = attention_bias
1471
+
1472
+ return log_probs, state
1473
+
1474
+ initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this.
1475
+ state: dict[str, torch.Tensor] = {"input_ids": input_ids}
1476
+ if attention_mask is not None:
1477
+ state["attention_mask"] = attention_mask
1478
+ if attention_bias is not None:
1479
+ state["attention_bias"] = attention_bias
1480
+ with torch.no_grad():
1481
+ token_ids, scores = beam_search.search(initial_preds, state, step)
1482
+
1483
+ return OLMoGenerateOutput(
1484
+ token_ids=token_ids, # type: ignore[arg-type]
1485
+ scores=scores, # type: ignore[arg-type]
1486
+ )
1487
+
1488
+ @classmethod
1489
+ def from_checkpoint(
1490
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
1491
+ ) -> OLMo:
1492
+ """
1493
+ Load an OLMo model from a checkpoint.
1494
+ """
1495
+ from .util import resource_path
1496
+
1497
+ # Guess checkpoint type.
1498
+ if checkpoint_type is None:
1499
+ try:
1500
+ if resource_path(checkpoint_dir, "model.pt").is_file():
1501
+ checkpoint_type = CheckpointType.unsharded
1502
+ else:
1503
+ checkpoint_type = CheckpointType.sharded
1504
+ except FileNotFoundError:
1505
+ checkpoint_type = CheckpointType.sharded
1506
+
1507
+ # Load config.
1508
+ config_path = resource_path(checkpoint_dir, "config.yaml")
1509
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1510
+
1511
+ if checkpoint_type == CheckpointType.unsharded:
1512
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1513
+ model_config.init_device = "cpu"
1514
+ model = OLMo(model_config)
1515
+
1516
+ # Load state dict directly to target device.
1517
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
1518
+ state_dict = torch.load(state_dict_path, map_location="cpu")
1519
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1520
+ model = model.to(torch.device(device))
1521
+ else:
1522
+ from .checkpoint import load_model_state
1523
+
1524
+ # Initialize model on target device. In this case the state dict is loaded in-place
1525
+ # so it's not necessary to start on CPU if the target device is a GPU.
1526
+ model_config.init_device = device
1527
+ model = OLMo(model_config)
1528
+
1529
+ # Load state dict in place.
1530
+ load_model_state(checkpoint_dir, model)
1531
+
1532
+ return model.eval()
1533
+
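A hedged usage sketch for the classmethod above; the checkpoint path is a placeholder, and the device falls back to CPU when CUDA is unavailable.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = OLMo.from_checkpoint("path/to/checkpoint_dir", device=device)   # placeholder path
print(model.num_params())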
1534
+ def _make_state_dict_compatible(
1535
+ self, state_dict: Dict[str, torch.Tensor]
1536
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
1537
+ """
1538
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
1539
+ be loaded.
1540
+
1541
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
1542
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
1543
+ to make a corresponding optimizer state dict compatible as well.
1544
+ """
1545
+ import re
1546
+ from fnmatch import fnmatch
1547
+
1548
+ new_keys_to_og_keys: Dict[str, str] = {}
1549
+
1550
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1551
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1552
+ # fine without the prefixes. This also simplifies the other steps below.
1553
+ for key in list(state_dict.keys()):
1554
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1555
+ new_keys_to_og_keys[new_key] = key
1556
+
1557
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1558
+ if self.config.block_type == BlockType.sequential:
1559
+ for key in list(state_dict.keys()):
1560
+ if fnmatch(key, "transformer.*.norm.weight"):
1561
+ tensor = state_dict.pop(key)
1562
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1563
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1564
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1565
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1566
+ del new_keys_to_og_keys[key]
1567
+ elif fnmatch(key, "transformer.*.norm.bias"):
1568
+ tensor = state_dict.pop(key)
1569
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1570
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1571
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1572
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1573
+ del new_keys_to_og_keys[key]
1574
+
1575
+ # For loading a state dict that was saved with a different `block_group_size`.
1576
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1577
+ state_dict_block_group_size = len(
1578
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1579
+ )
1580
+ else:
1581
+ state_dict_block_group_size = 1
1582
+ if self.config.block_group_size != state_dict_block_group_size:
1583
+ log.info(
1584
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1585
+ f"group size {self.config.block_group_size}"
1586
+ )
1587
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1588
+ # and then (re-)group them into the right block sizes.
1589
+ if state_dict_block_group_size > 1:
1590
+ for key in list(state_dict.keys()):
1591
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1592
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1593
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1594
+ state_dict[
1595
+ (
1596
+ new_key := key.replace(
1597
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1598
+ )
1599
+ )
1600
+ ] = state_dict.pop(key)
1601
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1602
+
1603
+ if self.config.block_group_size > 1:
1604
+ # Group the state dict blocks into the right block size.
1605
+ for key in list(state_dict.keys()):
1606
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1607
+ block_idx = int(m.group(1))
1608
+ group_idx, group_block_idx = (
1609
+ block_idx // self.config.block_group_size,
1610
+ block_idx % self.config.block_group_size,
1611
+ )
1612
+ state_dict[
1613
+ (
1614
+ new_key := key.replace(
1615
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1616
+ )
1617
+ )
1618
+ ] = state_dict.pop(key)
1619
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1620
+
1621
+ og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
1622
+ for new_key, og_key in new_keys_to_og_keys.items():
1623
+ og_keys_to_new[og_key].add(new_key)
1624
+
1625
+ return state_dict, og_keys_to_new
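For reference, a small hedged sketch of the regrouping arithmetic used above, assuming a target `block_group_size` of 2 (the mapping is pure index arithmetic, no tensors involved).

block_group_size = 2
for block_idx in range(4):
    group_idx, group_block_idx = divmod(block_idx, block_group_size)
    print(f"transformer.blocks.{block_idx}.  ->  transformer.block_groups.{group_idx}.{group_block_idx}.")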
model/modeling_olmo.py ADDED
@@ -0,0 +1,570 @@
1
+ import logging
2
+ from dataclasses import fields
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from transformers import PreTrainedModel
7
+ from transformers.cache_utils import Cache
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast, SequenceClassifierOutputWithPast
9
+ from transformers.models.auto import AutoModelForCausalLM, AutoModelForSequenceClassification
10
+
11
+ from .config import ModelConfig
12
+ from .model import OLMo
13
+ import sys
14
+ import os
15
+
16
+ # Add the parent directory to sys.path
17
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
18
+
19
+ from .configuration_olmo import OLMoConfig
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+
24
+ def create_model_config_from_pretrained_config(config: OLMoConfig, is_cls: bool = False):
25
+ """
26
+ Utility function: copies every `ModelConfig` field from a Hugging Face `OLMoConfig`; with `is_cls=True` it also returns `num_labels` for sequence classification.
27
+ """
28
+ kwargs = {}
29
+ for field in fields(ModelConfig):
30
+ kwargs[field.name] = getattr(config, field.name)
31
+ # also extract num_labels so the config is compatible with the AutoModelForSequenceClassification downstream task
32
+ model_config = ModelConfig(**kwargs)
33
+ if is_cls:
34
+ num_labels = len(getattr(config,'label2id'))
35
+ # print(f"{config}")
36
+ return model_config, num_labels
37
+ return model_config
38
+
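A hedged sketch of what the helper above is used for; it assumes the `config.json` shipped in this repo can be loaded as an `OLMoConfig` from the current directory.

hf_config = OLMoConfig.from_pretrained(".")                        # files from this repo
inner_config = create_model_config_from_pretrained_config(hf_config)
print(inner_config.d_model, inner_config.n_layers)                 # 512, 8 per config.json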
39
+
40
+ class OLMoForCausalLM(PreTrainedModel):
41
+ """
42
+ Extremely barebones HF model wrapper.
43
+ """
44
+
45
+ config_class = OLMoConfig
46
+ base_model_prefix = "model"
47
+ _no_split_modules = ["OLMoBlock"]
48
+
49
+ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
50
+ super().__init__(config)
51
+
52
+ if not model:
53
+ model_config = create_model_config_from_pretrained_config(config)
54
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
55
+ model_config.init_device = "cpu"
56
+ self.model = OLMo(model_config, init_params=init_params)
57
+ else:
58
+ self.model = model
59
+ self.word_embeddings = self.model.transformer.wte
60
+ def forward(
61
+ self,
62
+ input_ids: torch.LongTensor = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ attention_bias: Optional[torch.Tensor] = None,
66
+ token_type_ids: Optional[torch.LongTensor] = None, # Added parameter
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ labels: Optional[torch.LongTensor] = None,
69
+ use_cache: Optional[bool] = None,
70
+ output_attentions: Optional[bool] = None,
71
+ output_hidden_states: Optional[bool] = True,
72
+ return_dict: Optional[bool] = None,
73
+ cache_position: Optional[
74
+ Cache
75
+ ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
76
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
77
+ if use_cache is None:
78
+ use_cache = self.config.use_cache
79
+
80
+ if output_attentions:
81
+ raise ValueError("output_attentions is not yet supported in OLMo")
82
+
83
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
84
+
85
+ ######
86
+ # Create attention bias only if it's not provided for bidirectional finetuning
87
+ # Only uncomment this block when performing MNTP finetuning
88
+ ######
89
+ # if attention_bias is None:
90
+ # seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
91
+ # attention_bias = self.get_bidirectional_attention_bias(seq_len=seq_len, device=input_ids.device)
92
+
93
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
94
+ outputs = self.model.forward(
95
+ input_ids=input_ids,
96
+ input_embeddings=inputs_embeds,
97
+ attention_mask=attention_mask,
98
+ attention_bias=attention_bias,
99
+ past_key_values=past_key_values,
100
+ use_cache=use_cache,
101
+ output_hidden_states=output_hidden_states,
102
+ )
103
+
104
+ logits = outputs.logits
105
+ hidden_states = outputs.hidden_states
106
+
107
+ loss = None
108
+ if labels is not None:
109
+ # Shift so that tokens < n predict n
110
+ shift_logits = logits[..., :-1, :].contiguous()
111
+ shift_labels = labels[..., 1:].contiguous()
112
+ # Flatten the tokens
113
+ loss_fct = torch.nn.CrossEntropyLoss()
114
+ shift_logits = shift_logits.view(-1, self.config.embedding_size)
115
+ shift_labels = shift_labels.view(-1)
116
+ # Enable model parallelism
117
+ shift_labels = shift_labels.to(shift_logits.device)
118
+ loss = loss_fct(shift_logits, shift_labels)
119
+
120
+ if not return_dict:
121
+ output = (logits,) + outputs[1:]
122
+ return (loss,) + output if loss is not None else output
123
+
124
+ return CausalLMOutputWithPast(
125
+ loss=loss,
126
+ logits=logits,
127
+ past_key_values=outputs.attn_key_values,
128
+ hidden_states=hidden_states,
129
+ )
130
+
131
+ def can_generate(self) -> bool:
132
+ return True
133
+
134
+ def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device):
135
+ """
136
+ Create a bidirectional attention bias for full sequence attention.
137
+ The bias matrix will not restrict attention in any direction.
138
+ """
139
+ # Bias shape: (1, 1, seq_len, seq_len)
140
+ bias = torch.zeros(1, 1, seq_len, seq_len, device=device)
141
+ return bias
142
+
143
+ def prepare_inputs_for_generation(
144
+ self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
145
+ ):
146
+ if past_key_values:
147
+ # This is because we want the model to only process the last generated token.
148
+ input_ids = input_ids[:, -1:]
149
+ model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
150
+
151
+ model_inputs.update(kwargs)
152
+ model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
153
+ return model_inputs
154
+
155
+ # TODO: these are required to make the implementation complete.
156
+ # def resize_position_embeddings(self, new_num_position_embeddings: int):
157
+ # pass
158
+ #
159
+ # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
160
+ # pass
161
+ #
162
+ # def _reorder_cache(self, past_key_values, beam_idx):
163
+ # pass
164
+
165
+ def get_input_embeddings(self) -> torch.nn.Module:
166
+ return self.model.transformer.wte
167
+
168
+ def set_input_embeddings(self, value: torch.nn.Module):
169
+ self.model.transformer.wte = value
170
+
171
+ def get_output_embeddings(self):
172
+ if self.config.weight_tying:
173
+ return self.model.transformer.wte
174
+ else:
175
+ return self.model.transformer.ff_out
176
+
177
+ def set_output_embeddings(self, value: torch.nn.Module):
178
+ if self.config.weight_tying:
179
+ self.model.transformer.wte = value
180
+ else:
181
+ self.model.transformer.ff_out = value
182
+
183
+ def tie_weights(self):
184
+ """
185
+ This function is intentionally left as a no-op.
186
+
187
+ Weight tying is handled as follows:
188
+ - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
189
+ See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
190
+ - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
191
+ See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
192
+
193
+ Therefore, there is no need to explicitly tie the weights in this function.
194
+ """
195
+ pass
196
+
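A hedged toy sketch of the tying convention described in the docstring above: with `weight_tying` enabled, the output projection simply reuses the input embedding matrix, so there is nothing to tie explicitly.

import torch
import torch.nn.functional as F

wte = torch.nn.Embedding(16, 8)          # toy vocab of 16 tokens, toy d_model of 8
x = torch.randn(2, 3, 8)                 # (batch, seq, d_model) hidden states
logits = F.linear(x, wte.weight)         # (batch, seq, 16), same weights as the embedding
print(logits.shape)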
197
+ def resize_token_embeddings(
198
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
199
+ ) -> torch.nn.Embedding:
200
+ """
201
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
202
+
203
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
204
+
205
+ Arguments:
206
+ new_num_tokens (`int`, *optional*):
207
+ The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
208
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
209
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
210
+ pad_to_multiple_of (`int`, *optional*):
211
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
212
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
213
+
214
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
215
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
216
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
217
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
218
+
219
+ Return:
220
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
221
+
222
+ Note:
223
+ This method differs from the base class implementation by resizing the `embedding_size` attribute of the
224
+ model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
225
+ is less than the `vocab_size`. In OLMo, `embedding_size` refers to the number of rows in the token
226
+ embedding matrix (which may be padded above the vocabulary size for throughput), while `vocab_size` refers to the number of unique tokens in the vocabulary.
227
+ """
228
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
229
+ if new_num_tokens is None and pad_to_multiple_of is None:
230
+ return model_embeds
231
+
232
+ # Update base model and current model config
233
+ self.config.embedding_size = model_embeds.weight.shape[0]
234
+ self.model.config.embedding_size = model_embeds.weight.shape[0]
235
+
236
+ # Check if the embedding size is less than the vocab size
237
+ if self.config.embedding_size < self.config.vocab_size:
238
+ warning_message = (
239
+ f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
240
+ f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
241
+ "size is less than or equal to the new token embedding size."
242
+ )
243
+ log.warning(warning_message)
244
+
245
+ # Tie weights again if needed
246
+ self.tie_weights()
247
+
248
+ return model_embeds
249
+
250
+
251
+ # Register the model so that it is available for transformer pipelines, auto-loading, etc.
252
+ AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
253
+
254
+
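With the registration above in place, the wrapper can be loaded through the standard transformers entry point. A hedged sketch; the repo id is a placeholder and `trust_remote_code=True` is required when loading the custom classes from the Hub.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-user/olmo-gfm-checkpoint", trust_remote_code=True)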
255
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
256
+ class OLMoForSequenceCLS(PreTrainedModel):
257
+ """
258
+ Extremely barebones HF model wrapper.
259
+ """
260
+
261
+ config_class = OLMoConfig
262
+ base_model_prefix = "model"
263
+ _no_split_modules = ["OLMoBlock"]
264
+
265
+ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
266
+ super().__init__(config)
267
+ if not model:
268
+ model_config, num_labels = create_model_config_from_pretrained_config(config, is_cls=True)
269
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
270
+ model_config.init_device = "cpu"
271
+ self.model = OLMo(model_config, init_params=init_params)
272
+ else:
273
+ self.model = model
274
+ self.word_embeddings = self.model.transformer.wte
275
+ self.num_labels = num_labels if not model else len(getattr(config, 'label2id'))  # avoid a NameError when a prebuilt model is passed in
276
+ print(f"num_labels: {self.num_labels}")
277
+ self.score = torch.nn.Linear(config.hidden_size, self.num_labels, bias=False)
278
+
279
+
280
+ ###############
281
+ # mix resolution head
282
+ ################
283
+ # self.CNN = CNN_Head(output_size=self.num_labels,cnn_output_dim=config.hidden_size, kernel_sizes=[4,9],dropout_rate=0.11,
284
+ # num_cnn_layers=2)
285
+ def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device):
286
+ """
287
+ Create a bidirectional attention bias for full sequence attention.
288
+ The bias matrix will not restrict attention in any direction.
289
+ """
290
+ # Bias shape: (1, 1, seq_len, seq_len)
291
+ bias = torch.zeros(1, 1, seq_len, seq_len, device=device)
292
+ return bias
293
+ def forward(
294
+ self,
295
+ input_ids: torch.LongTensor = None,
296
+ inputs_embeds: Optional[torch.FloatTensor] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ attention_bias: Optional[torch.Tensor] = None,
299
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
300
+ labels: Optional[torch.LongTensor] = None,
301
+ use_cache: Optional[bool] = None,
302
+ output_attentions: Optional[bool] = None,
303
+ output_hidden_states: Optional[bool] = None,
304
+ return_dict: Optional[bool] = None,
305
+ cache_position: Optional[
306
+ Cache
307
+ ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
308
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
309
+ if use_cache is None:
310
+ use_cache = self.config.use_cache
311
+
312
+ if output_attentions:
313
+ raise ValueError("output_attentions is not yet supported in OLMo")
314
+ ######
315
+ # Create attention bias only if it's not provided
316
+ ######
317
+ # if attention_bias is None:
318
+ # seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
319
+ # attention_bias = self.get_bidirectional_attention_bias(seq_len=seq_len, device=input_ids.device)
320
+ ######
321
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
322
+ ########
323
+ # The output_hidden_states flag must be set because OLMo's output format is the following:
324
+ # return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)
325
+ # so we have to forcibly set the output_hidden_states flag
326
+ ########
327
+ output_hidden_states = True
328
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
329
+ outputs = self.model.forward(
330
+ input_ids=input_ids,
331
+ input_embeddings=inputs_embeds,
332
+ attention_mask=attention_mask,
333
+ attention_bias=attention_bias,
334
+ past_key_values=past_key_values,
335
+ use_cache=use_cache,
336
+ output_hidden_states=output_hidden_states,
337
+ )
338
+ hidden_states = outputs.hidden_states[-1]
339
+ # pool the logit of the last non-padding token: the code below indexes position attention_mask.sum(-1) - 1,
340
+ # which assumes right padding (with left padding the last non-padding logit would instead be logit[:, -1, :])
341
+ logits = self.score(hidden_states)
342
+ ##########
343
+ seq_lengths = attention_mask.sum(dim=-1)
344
+ # instead of mean pooling, take the hidden state of the last real token, located via the sequence length
345
+ pooled_logits = torch.stack(
346
+ [
347
+ logits[i, length - 1, :]
348
+ for i, length in enumerate(seq_lengths)
349
+ ],
350
+ dim=0,
351
+ )
352
+ ##########
353
+ loss = None
354
+ if labels is not None:
355
+ if self.config.problem_type is None:
356
+ if self.num_labels == 1:
357
+ self.config.problem_type = "regression"
358
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
359
+ self.config.problem_type = "single_label_classification"
360
+
361
+ if self.config.problem_type == "regression":
362
+ loss_fct = MSELoss()
363
+ if self.num_labels == 1:
364
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
365
+ else:
366
+ loss = loss_fct(pooled_logits, labels)
367
+ elif self.config.problem_type == "single_label_classification":
368
+ loss_fct = CrossEntropyLoss()
369
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
370
+
371
+ if not return_dict:
372
+ output = (pooled_logits,) + outputs[1:]
373
+ return ((loss,) + output) if loss is not None else output
374
+ return SequenceClassifierOutputWithPast(
375
+ loss=loss,
376
+ logits=pooled_logits,
377
+ past_key_values=outputs.attn_key_values,
378
+ hidden_states=hidden_states,
379
+ )
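A hedged toy sketch of the last-token pooling performed in the forward pass above, using a right-padded attention mask.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
logits = torch.randn(2, 5, 3)                                    # (batch, seq, num_labels)
lengths = attention_mask.sum(dim=-1)                             # tensor([3, 5])
pooled = torch.stack([logits[i, l - 1] for i, l in enumerate(lengths)])
print(pooled.shape)                                              # torch.Size([2, 3])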
380
+ def forward_new(
381
+ self,
382
+ input_ids: torch.LongTensor = None,
383
+ inputs_embeds: Optional[torch.FloatTensor] = None,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ attention_bias: Optional[torch.Tensor] = None,
386
+ onehot: Optional[torch.Tensor] = None, # New field
387
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
388
+ labels: Optional[torch.LongTensor] = None,
389
+ use_cache: Optional[bool] = None,
390
+ output_attentions: Optional[bool] = None,
391
+ output_hidden_states: Optional[bool] = None,
392
+ return_dict: Optional[bool] = None,
393
+ cache_position: Optional[
394
+ Cache
395
+ ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
396
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
397
+ if use_cache is None:
398
+ use_cache = self.config.use_cache
399
+
400
+ if output_attentions:
401
+ raise ValueError("output_attentions is not yet supported in OLMo")
402
+ ######
403
+ # input_ids shape
404
+ ######
405
+
406
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
407
+ ########
408
+ # The output_hidden_states flag is set as the output format of olmo is the following:
409
+ # return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)
410
+ # so we have to forcely set the output hidden_states flag
411
+ ########
412
+ output_hidden_states = True
413
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
414
+ #----------
415
+ # outputs = self.model.forward(
416
+ # input_ids=input_ids,
417
+ # input_embeddings=inputs_embeds,
418
+ # attention_mask=attention_mask,
419
+ # attention_bias=attention_bias,
420
+ # past_key_values=past_key_values,
421
+ # use_cache=use_cache,
422
+ # output_hidden_states=output_hidden_states,
423
+ # )
424
+ # hidden_states = outputs.hidden_states[-1]
425
+ #-------------
426
+ # assume that the padding is done by prepadding at the left of the input sequence
427
+ # the logit of the last non-padding token is logit[:,-1,:]
428
+ # logits = self.score(hidden_states)
429
+ # pooled_logits = hidden_states[:,-1,:]
430
+ pooled_logits = self.CNN(onehot)
431
+
432
+ loss = None
433
+ if labels is not None:
434
+ if self.config.problem_type is None:
435
+ if self.num_labels == 1:
436
+ self.config.problem_type = "regression"
437
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
438
+ self.config.problem_type = "single_label_classification"
439
+
440
+ if self.config.problem_type == "regression":
441
+ loss_fct = MSELoss()
442
+ if self.num_labels == 1:
443
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
444
+ else:
445
+ loss = loss_fct(pooled_logits, labels)
446
+ elif self.config.problem_type == "single_label_classification":
447
+ loss_fct = CrossEntropyLoss()
448
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
449
+
450
+ # if not return_dict:
451
+ # output = (pooled_logits,) + outputs[1:] #------
452
+ # return ((loss,) + output) if loss is not None else output
453
+ return SequenceClassifierOutputWithPast(
454
+ loss=loss,
455
+ logits=pooled_logits,
456
+ # past_key_values=outputs.attn_key_values,
457
+ # hidden_states=hidden_states,
458
+ )
459
+
460
+ def can_generate(self) -> bool:
461
+ return True
462
+
463
+ def prepare_inputs_for_generation(
464
+ self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
465
+ ):
466
+ if past_key_values:
467
+ # This is because we want the model to only process the last generated token.
468
+ input_ids = input_ids[:, -1:]
469
+ model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
470
+
471
+ model_inputs.update(kwargs)
472
+ model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
473
+ return model_inputs
474
+
475
+ # TODO: these are required to make the implementation complete.
476
+ # def resize_position_embeddings(self, new_num_position_embeddings: int):
477
+ # pass
478
+ #
479
+ # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
480
+ # pass
481
+ #
482
+ # def _reorder_cache(self, past_key_values, beam_idx):
483
+ # pass
484
+
485
+ def get_input_embeddings(self) -> torch.nn.Module:
486
+ return self.model.transformer.wte
487
+
488
+ def set_input_embeddings(self, value: torch.nn.Module):
489
+ self.model.transformer.wte = value
490
+
491
+ def get_output_embeddings(self):
492
+ if self.config.weight_tying:
493
+ return self.model.transformer.wte
494
+ else:
495
+ return self.model.transformer.ff_out
496
+
497
+ def set_output_embeddings(self, value: torch.nn.Module):
498
+ if self.config.weight_tying:
499
+ self.model.transformer.wte = value
500
+ else:
501
+ self.model.transformer.ff_out = value
502
+
503
+ def tie_weights(self):
504
+ """
505
+ This function is intentionally left as a no-op.
506
+
507
+ Weight tying is handled as follows:
508
+ - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
509
+ See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
510
+ - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
511
+ See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
512
+
513
+ Therefore, there is no need to explicitly tie the weights in this function.
514
+ """
515
+ pass
516
+
517
+ def resize_token_embeddings(
518
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
519
+ ) -> torch.nn.Embedding:
520
+ """
521
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
522
+
523
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
524
+
525
+ Arguments:
526
+ new_num_tokens (`int`, *optional*):
527
+ The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
528
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
529
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
530
+ pad_to_multiple_of (`int`, *optional*):
531
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
532
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
533
+
534
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
535
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
536
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
537
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
538
+
539
+ Return:
540
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
541
+
542
+ Note:
543
+ This method differs from the base class implementation by resizing the `embedding_size` attribute of the
544
+ model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
545
+ is less than the `vocab_size`. In OLMo, `embedding_size` refers to the number of rows in the token
546
+ embedding matrix (which may be padded above the vocabulary size for throughput), while `vocab_size` refers to the number of unique tokens in the vocabulary.
547
+ """
548
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
549
+ if new_num_tokens is None and pad_to_multiple_of is None:
550
+ return model_embeds
551
+
552
+ # Update base model and current model config
553
+ self.config.embedding_size = model_embeds.weight.shape[0]
554
+ self.model.config.embedding_size = model_embeds.weight.shape[0]
555
+
556
+ # Check if the embedding size is less than the vocab size
557
+ if self.config.embedding_size < self.config.vocab_size:
558
+ warning_message = (
559
+ f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
560
+ f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
561
+ "size is less than or equal to the new token embedding size."
562
+ )
563
+ log.warning(warning_message)
564
+
565
+ # Tie weights again if needed
566
+ self.tie_weights()
567
+
568
+ return model_embeds
569
+ # Register the model so that it is available for transformer pipelines, auto-loading, etc.
570
+ AutoModelForSequenceClassification.register(OLMoConfig, OLMoForSequenceCLS)
model/optim.py ADDED
@@ -0,0 +1,778 @@
1
+ import logging
2
+ from abc import ABCMeta, abstractmethod
3
+ from dataclasses import dataclass, replace
4
+ from math import cos, pi, sqrt
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ from torch.distributed.fsdp import FullyShardedDataParallel
11
+ from torch.optim.optimizer import Optimizer as OptimizerBase
12
+
13
+ from . import LayerNormBase
14
+ from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
15
+ from .torch_util import get_default_device, is_distributed
16
+
17
+ __all__ = [
18
+ "Optimizer",
19
+ "LionW",
20
+ "AdamW",
21
+ "Scheduler",
22
+ "CosWithWarmup",
23
+ "LinearWithWarmup",
24
+ "InvSqrtWithWarmup",
25
+ "MaxScheduler",
26
+ "ConstantScheduler",
27
+ "BoltOnWarmupScheduler",
28
+ "build_optimizer",
29
+ "build_scheduler",
30
+ ]
31
+
32
+
33
+ log = logging.getLogger(__name__)
34
+
35
+
36
+ class Optimizer(OptimizerBase):
37
+ def _clean_param_name(self, name: str) -> str:
38
+ return name.replace("_fsdp_wrapped_module.", "")
39
+
40
+ @torch.no_grad()
41
+ def clip_grads_and_collect_metrics(
42
+ self, global_step: int, collect_param_metrics: bool = True
43
+ ) -> Dict[str, torch.Tensor]:
44
+ """
45
+ Clips gradients for every group that has the field `max_grad_norm`.
46
+ At the same time collect metrics for each parameter and its gradient.
47
+ """
48
+ device = get_default_device()
49
+
50
+ # NOTE (epwalsh): during distributed training we're making an assumption that the order of
51
+ # the param groups and the params within each group are the same across all ranks.
52
+ # This is justified since we initialize the parameter groups in every rank by iterating over
53
+ # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which
54
+ # provides a consistent order.
55
+ # For each parameter (with a gradient) we'll collect:
56
+ # - min, max, avg, norm of the param itself
57
+ # - min, max, avg, norm of the param's gradient
58
+ # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from
59
+ # `self.get_state_for_param()`.
60
+ # Afterwards we'll reduce these all over all ranks.
61
+ per_param_min_metrics: List[torch.Tensor] = []
62
+ per_param_max_metrics: List[torch.Tensor] = []
63
+ per_param_sum_metrics: List[torch.Tensor] = []
64
+ per_param_norm_metrics: List[torch.Tensor] = []
65
+ per_param_numel_metrics: List[torch.Tensor] = []
66
+
67
+ per_param_min_metric_names: List[str] = []
68
+ per_param_max_metric_names: List[str] = []
69
+ per_param_avg_metric_names: List[str] = []
70
+ per_param_norm_metric_names: List[str] = []
71
+
72
+ # Collect metrics locally.
73
+ for group in self.param_groups:
74
+ if is_distributed():
75
+ # TODO (epwalsh): handle non-sharded params. We don't have any right now but we would
76
+ # with ReLoRa, for example.
77
+ assert group.get("sharded", True) is True
78
+
79
+ for name, p in zip(group["param_names"], group["params"]):
80
+ name = self._clean_param_name(name)
81
+ # Always need to collect the norm of gradients for clipping, even if we're not collecting
82
+ # other metrics.
83
+ tensors: List[Optional[torch.Tensor]] = [p.grad]
84
+ prefixes: List[str] = [f"grad/{name}"]
85
+ if collect_param_metrics:
86
+ state = self.get_state_for_param(p)
87
+ sorted_state_keys = sorted([k for k in state.keys()])
88
+ tensors.extend([p] + [state[key] for key in sorted_state_keys])
89
+ prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys])
90
+ assert len(tensors) == len(prefixes)
91
+
92
+ # Get min, max, avg, and norm for all `tensors` associated with the parameter.
93
+ for x, prefix in zip(tensors, prefixes):
94
+ # grad or state tensors could be none for params that have their shards completely on
95
+ # other ranks.
96
+ if x is not None and x.numel() > 0:
97
+ if collect_param_metrics:
98
+ x_abs = x.abs()
99
+ per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
100
+ per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
101
+ per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
102
+ per_param_numel_metrics.append(
103
+ torch.tensor([x.numel()], device=device, dtype=torch.float32)
104
+ )
105
+ per_param_norm_metrics.append(
106
+ torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
107
+ )
108
+ else:
109
+ if collect_param_metrics:
110
+ per_param_min_metrics.append(
111
+ torch.tensor([float("inf")], device=device, dtype=torch.float32)
112
+ )
113
+ per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
114
+ per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
115
+ per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
116
+ per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
117
+ if collect_param_metrics:
118
+ per_param_min_metric_names.append(f"{prefix}.min")
119
+ per_param_max_metric_names.append(f"{prefix}.max")
120
+ per_param_avg_metric_names.append(f"{prefix}.avg")
121
+ per_param_norm_metric_names.append(f"{prefix}.norm")
122
+
123
+ assert (
124
+ len(per_param_min_metrics)
125
+ == len(per_param_min_metric_names)
126
+ == len(per_param_max_metrics)
127
+ == len(per_param_max_metric_names)
128
+ == len(per_param_sum_metrics)
129
+ == len(per_param_numel_metrics)
130
+ == len(per_param_avg_metric_names)
131
+ )
132
+ assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)
133
+
134
+ def is_grad_norm_metric(metric_name: str) -> bool:
135
+ return metric_name.startswith("grad/") and metric_name.endswith(".norm")
136
+
137
+ # Now reduce metrics over all ranks.
138
+ total_grad_norm: torch.Tensor
139
+ per_param_avg_metrics: List[torch.Tensor] = []
140
+ if is_distributed(): # TODO (epwalsh): skip for non-sharded params
141
+ # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
142
+ # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks
143
+ # get the right value for gradient norms so they can clip correctly.
144
+ # Reduce mins.
145
+ if per_param_min_metrics:
146
+ all_mins = torch.cat(per_param_min_metrics).to(device)
147
+ dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
148
+ per_param_min_metrics = all_mins.split(1)
149
+ # Reduce maxs.
150
+ if per_param_max_metrics:
151
+ all_maxs = torch.cat(per_param_max_metrics).to(device)
152
+ dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
153
+ per_param_max_metrics = all_maxs.split(1)
154
+ # Reduce sums or just norms.
155
+ all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
156
+ if per_param_sum_metrics and per_param_numel_metrics:
157
+ all_sums = torch.cat(per_param_sum_metrics).to(device)
158
+ all_numels = torch.cat(per_param_numel_metrics).to(device)
159
+ all_sums_norms_numels = torch.cat(
160
+ [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
161
+ )
162
+ dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
163
+ all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
164
+ # Get averages.
165
+ # NOTE: could get infs for non-rank0 processes but that's okay.
166
+ per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
167
+ else:
168
+ dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
169
+ grad_norm_metric_mask = torch.tensor(
170
+ [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
171
+ )
172
+ total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
173
+ per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
174
+ else:
175
+ total_grad_norm = (
176
+ torch.cat(
177
+ [
178
+ m
179
+ for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
180
+ if is_grad_norm_metric(n)
181
+ ]
182
+ )
183
+ ** 2.0
184
+ ).sum() ** 0.5
185
+ per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]
186
+
187
+ assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
188
+
189
+ # Collect all metrics into a single dict.
190
+ all_metrics: Dict[str, torch.Tensor] = {}
191
+ for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics):
192
+ all_metrics[metric_name] = metric.squeeze(0)
193
+ for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics):
194
+ all_metrics[metric_name] = metric.squeeze(0)
195
+ for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics):
196
+ all_metrics[metric_name] = metric.squeeze(0)
197
+ for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
198
+ all_metrics[metric_name] = metric.squeeze(0)
199
+ all_metrics["total_grad_norm"] = total_grad_norm
200
+
201
+ # Clip gradients.
202
+ num_grads_clipped = 0
203
+ num_eligible_grads = 0
204
+ for group in self.param_groups:
205
+ if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
206
+ num_clipped = self._do_adaptive_clipping(
207
+ group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
208
+ )
209
+ elif (max_norm := group.get("max_grad_norm")) is not None:
210
+ num_clipped = self._do_global_fixed_clipping(
211
+ group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
212
+ )
213
+ else:
214
+ # No clipping needed.
215
+ continue
216
+ num_eligible_grads += len(group["params"])
217
+ if num_clipped is not None:
218
+ num_grads_clipped += num_clipped
219
+
220
+ if collect_param_metrics:
221
+ if num_eligible_grads > 0:
222
+ clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
223
+ else:
224
+ clipping_rate = torch.tensor(0.0, device="cpu")
225
+ all_metrics["clipping_rate"] = clipping_rate
226
+ return all_metrics
227
+ else:
228
+ return {}
229
+
230
+ @torch.no_grad()
231
+ def _do_adaptive_clipping(
232
+ self,
233
+ group: Dict[str, Any],
234
+ max_norm_ratio: float,
235
+ global_step: int,
236
+ all_metrics: Dict[str, torch.Tensor],
237
+ collect_param_metrics: bool = True,
238
+ ) -> Optional[int]:
239
+ """
240
+ Do adaptive gradient clipping on a param group.
241
+
242
+ If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
243
+ """
244
+ device = get_default_device()
245
+ num_grads_clipped = 0
246
+ # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
247
+ # the gradient (a scalar), not to be confused with the exponential average of the gradient.
248
+ # TODO (epwalsh): handle optimizers that don't have betas.
249
+ beta1, beta2 = group["betas"]
250
+ beta = max(beta1, beta2)
251
+ for name, p in zip(group["param_names"], group["params"]):
252
+ name = self._clean_param_name(name)
253
+ grad_norm = all_metrics.get(f"grad/{name}.norm")
254
+ if grad_norm is None:
255
+ continue
256
+
257
+ # Get or initialize the exponential average of grad norm.
258
+ # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
259
+ # even parameters for which the corresponding local shard is empty. This has the potential to
260
+ # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
261
+ # So we should consider changing how we do this at some point so that we don't add any state
262
+ # to parameters for which the local shard is empty. That would probably add extra distributed
263
+ # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
264
+ state = self.state[p]
265
+ grad_norm_exp_avg = state.get("grad_norm_exp_avg")
266
+ if grad_norm_exp_avg is None:
267
+ grad_norm_exp_avg = grad_norm.clone().to(device)
268
+ # We don't want to add anything to `state` until `state` has been initialized, otherwise
269
+ # this will crash some optimizers which rely on checking `len(state)`. The downside here
270
+ # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
271
+ if global_step > 1:
272
+ state["grad_norm_exp_avg"] = grad_norm_exp_avg
273
+
274
+ max_allowed_norm = max_norm_ratio * grad_norm_exp_avg
275
+ clip_coef = max_allowed_norm / (grad_norm + 1e-6)
276
+
277
+ # Clip the gradients and update the exponential average.
278
+ # Note that multiplying by the clamped coefficient is meaningless when it is
279
+ # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
280
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
281
+ if p.grad is not None:
282
+ # p.grad could be none for some ranks when using FSDP.
283
+ p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
284
+
285
+ # Update the exponential average of the norm of the gradient with the clipped norm of the gradient.
286
+ grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta)
287
+ # Alternative: update with the *unclipped* norm of the gradient.
288
+ # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
289
+
290
+ if collect_param_metrics:
291
+ # Can't avoid host-device sync here.
292
+ if clip_coef_clamped < 1.0:
293
+ num_grads_clipped += 1
294
+ all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
295
+ return num_grads_clipped if collect_param_metrics else None
296
+
297
+ @torch.no_grad()
298
+ def _do_global_fixed_clipping(
299
+ self,
300
+ group: Dict[str, Any],
301
+ max_norm: float,
302
+ all_metrics: Dict[str, torch.Tensor],
303
+ collect_param_metrics: bool = True,
304
+ ) -> Optional[int]:
305
+ """
306
+ Do global fixed gradient clipping on a param group.
307
+
308
+ If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
309
+ """
310
+ device = get_default_device()
311
+ total_grad_norm = all_metrics["total_grad_norm"]
312
+ clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
313
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
314
+ num_grads_clipped: Optional[int] = None
315
+ if collect_param_metrics:
316
+ # Can't avoid host-device sync here.
317
+ if clip_coef_clamped < 1.0:
318
+ num_grads_clipped = len(group["params"])
319
+ for p in group["params"]:
320
+ # Clip the gradients.
321
+ # Note that multiplying by the clamped coefficient is meaningless when it is
322
+ # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
323
+ if p.grad is not None:
324
+ # p.grad could be None for some ranks when using FSDP.
325
+ p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
326
+ return num_grads_clipped
327
+
328
+ def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
329
+ del module
330
+ return {}
331
+
332
+ def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
333
+ del param
334
+ return {}
335
+
336
+
337
+ class LionW(Optimizer):
338
+ """
339
+ Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py
340
+ """
341
+
342
+ def __init__(
343
+ self,
344
+ params,
345
+ lr: float = 1e-4,
346
+ betas: Tuple[float, float] = (0.9, 0.99),
347
+ weight_decay: float = 0.0,
348
+ ):
349
+ assert lr > 0.0
350
+ assert all([0.0 <= beta <= 1.0 for beta in betas])
351
+ defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
352
+ super().__init__(params, defaults)
353
+ for group in self.param_groups:
354
+ group["initial_lr"] = group["lr"]
355
+ self._update_total_dot_prod: Optional[torch.Tensor] = None
356
+ self._update_total_norm: Optional[torch.Tensor] = None
357
+ self._signed_update_total_norm: Optional[torch.Tensor] = None
358
+
359
+ def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
360
+ update_total_dot_prod = self._update_total_dot_prod
361
+ update_total_norm = self._update_total_norm
362
+ signed_update_total_norm = self._signed_update_total_norm
363
+ if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
364
+ return {}
365
+
366
+ if is_distributed() and isinstance(module, FullyShardedDataParallel):
367
+ # Reduce total dot prod and norms across all ranks.
368
+ update_total_norm = update_total_norm**2.0
369
+ signed_update_total_norm = signed_update_total_norm**2.0
370
+ # Reduce all together to avoid multiple communication calls.
371
+ all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
372
+ # Only need the final result on rank0, since that's where we log from.
373
+ dist.reduce(all_together, 0)
374
+ update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
375
+ update_total_norm = update_total_norm**0.5
376
+ signed_update_total_norm = signed_update_total_norm**0.5
377
+
378
+ update_cos_sim = update_total_dot_prod / torch.max(
379
+ update_total_norm * signed_update_total_norm, torch.tensor(1e-8, device=get_default_device())
380
+ )
381
+ return {"update_cos_sim": update_cos_sim}
382
+
383
+ @torch.no_grad()
384
+ def step(self, closure=None) -> None:
385
+ if closure is not None:
386
+ with torch.enable_grad():
387
+ closure()
388
+
389
+ update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
390
+ update_norms = []
391
+ signed_update_norms = []
392
+
393
+ for group in self.param_groups:
394
+ for p in group["params"]:
395
+ if p.grad is None:
396
+ continue
397
+
398
+ # Perform step weight decay
399
+ p.data.mul_(1 - group["lr"] * group["weight_decay"])
400
+
401
+ grad = p.grad
402
+ state = self.state[p]
403
+
404
+ # State initialization
405
+ if len(state) == 0:
406
+ # Exponential moving average of gradient values
407
+ state["exp_avg"] = torch.zeros_like(p)
408
+
409
+ exp_avg = state["exp_avg"]
410
+ beta1, beta2 = group["betas"]
411
+
412
+ # Weight update
413
+ update = exp_avg * beta1 + grad * (1 - beta1)
414
+ signed_update = torch.sign(update)
415
+ p.add_(signed_update, alpha=-group["lr"])
416
+
417
+ # Decay the momentum running average coefficient
418
+ exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
419
+
420
+ # Track dot product and norms of update vs signed update in order to calculate
421
+ # their cosine similarity.
422
+ update_total_dot_prod = update_total_dot_prod.to(update.device)
423
+ update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
424
+ update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
425
+ signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
426
+
427
+ # Compute cosine similarity between update and signed update.
428
+ self._update_total_dot_prod = update_total_dot_prod.to(get_default_device())
429
+ self._update_total_norm = torch.linalg.vector_norm(
430
+ torch.stack(update_norms),
431
+ 2.0,
432
+ dtype=torch.float32,
433
+ ).to(get_default_device())
434
+ self._signed_update_total_norm = torch.linalg.vector_norm(
435
+ torch.stack(signed_update_norms),
436
+ 2.0,
437
+ dtype=torch.float32,
438
+ ).to(get_default_device())
439
+
440
+
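The core of `step()` above is the sign-of-interpolated-momentum update from the Lion paper; a minimal sketch on a single parameter with made-up hyperparameters:

    import torch

    lr, wd, beta1, beta2 = 1e-4, 0.1, 0.9, 0.99
    p = torch.randn(4)
    grad = torch.randn(4)
    exp_avg = torch.zeros_like(p)

    p.mul_(1 - lr * wd)                              # decoupled weight decay
    update = exp_avg * beta1 + grad * (1 - beta1)    # interpolate momentum and gradient
    p.add_(torch.sign(update), alpha=-lr)            # fixed-magnitude signed step
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)  # momentum update uses beta2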
441
+ class AdamW(torch.optim.AdamW, Optimizer):
442
+ def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
443
+ return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")} # type: ignore
444
+
445
+
446
+ @dataclass
447
+ class Scheduler(metaclass=ABCMeta):
448
+ # NOTE: these fields are not given default values because otherwise dataclasses complains
449
+ # about how the scheduler subclasses are defined.
450
+ grad_clip_warmup_steps: Optional[int]
451
+ grad_clip_warmup_factor: Optional[float]
452
+ warmup_min_lr: Optional[float]
453
+
454
+ @abstractmethod
455
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
456
+ raise NotImplementedError
457
+
458
+ def _get_max_grad_norm_coeff(
459
+ self, initial_value: Optional[float], step: int, max_steps: int
460
+ ) -> Optional[float]:
461
+ del max_steps # might need this in the future, but for now I just wanted to match the API of `get_lr()`.
462
+ if initial_value is None:
463
+ return None
464
+ elif (
465
+ self.grad_clip_warmup_steps is None
466
+ or self.grad_clip_warmup_factor is None
467
+ or step > self.grad_clip_warmup_steps
468
+ ):
469
+ return initial_value
470
+ else:
471
+ return self.grad_clip_warmup_factor * initial_value
472
+
473
+ def get_max_grad_norm(
474
+ self, initial_max_grad_norm: Optional[float], step: int, max_steps: int
475
+ ) -> Optional[float]:
476
+ return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps)
477
+
478
+ def get_max_grad_norm_ratio(
479
+ self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int
480
+ ) -> Optional[float]:
481
+ return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps)
482
+
483
+ def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float:
484
+ warmup_min_lr = self.warmup_min_lr if self.warmup_min_lr is not None else initial_lr * 0.10
485
+ assert 0 <= warmup_min_lr < initial_lr
486
+ return warmup_min_lr + (initial_lr - warmup_min_lr) * min(step, warmup_steps) / warmup_steps
487
+
488
+
489
+ @dataclass
490
+ class CosWithWarmup(Scheduler):
491
+ warmup_steps: int
492
+ alpha_f: float = 0.1
493
+ t_max: Optional[int] = None
494
+
495
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
496
+ max_steps = max_steps if self.t_max is None else self.t_max
497
+ eta_min = initial_lr * self.alpha_f
498
+ if step < self.warmup_steps:
499
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
500
+ elif step >= max_steps:
501
+ return eta_min
502
+ else:
503
+ step = step - self.warmup_steps
504
+ max_steps = max_steps - self.warmup_steps
505
+ return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
506
+
507
+
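To see the curve these dataclasses produce, here is a hedged example that instantiates `CosWithWarmup` directly (the grad-clip warmup fields are set to `None`, as `BoltOnWarmupScheduler.wrap` does further down):

    sched = CosWithWarmup(
        grad_clip_warmup_steps=None,
        grad_clip_warmup_factor=None,
        warmup_min_lr=None,
        warmup_steps=100,
        alpha_f=0.1,
    )
    lrs = [sched.get_lr(initial_lr=1e-3, step=s, max_steps=1000) for s in range(0, 1001, 100)]
    # Linear ramp from 1e-4 up to 1e-3 over the first 100 steps, then cosine decay back to 1e-4.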
508
+ @dataclass
509
+ class LinearWithWarmup(Scheduler):
510
+ warmup_steps: int
511
+ alpha_f: float = 0.1
512
+ t_max: Optional[int] = None
513
+
514
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
515
+ max_steps = max_steps if self.t_max is None else self.t_max
516
+ eta_min = initial_lr * self.alpha_f
517
+ if step < self.warmup_steps:
518
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
519
+ elif step >= max_steps:
520
+ return eta_min
521
+ else:
522
+ step = step - self.warmup_steps
523
+ max_steps = max_steps - self.warmup_steps
524
+ return initial_lr - (initial_lr - eta_min) * (step / max_steps)
525
+
526
+
527
+ @dataclass
528
+ class InvSqrtWithWarmup(Scheduler):
529
+ warmup_steps: int
530
+
531
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
532
+ if step < self.warmup_steps:
533
+ return self._linear_warmup(initial_lr, step, self.warmup_steps)
534
+ del max_steps
535
+ return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step))
536
+
537
+
538
+ @dataclass
539
+ class MaxScheduler(Scheduler):
540
+ sched1: Scheduler
541
+ sched2: Scheduler
542
+
543
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
544
+ return max(
545
+ self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps)
546
+ )
547
+
548
+
549
+ @dataclass
550
+ class BoltOnWarmupScheduler(Scheduler):
551
+ inner: Scheduler
552
+ warmup_start: int
553
+ warmup_end: int
554
+
555
+ @classmethod
556
+ def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler":
557
+ return cls(
558
+ grad_clip_warmup_steps=None,
559
+ grad_clip_warmup_factor=None,
560
+ inner=scheduler,
561
+ warmup_start=warmup_start,
562
+ warmup_end=warmup_end,
563
+ warmup_min_lr=None,
564
+ )
565
+
566
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
567
+ if step < self.warmup_start:
568
+ return 0.0
569
+ if step < self.warmup_end:
570
+ lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
571
+ return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
572
+ else:
573
+ return self.inner.get_lr(initial_lr, step, max_steps)
574
+
575
+ def _get_max_grad_norm_coeff(
576
+ self, initial_value: Optional[float], step: int, max_steps: int
577
+ ) -> Optional[float]:
578
+ return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps)
579
+
580
+
581
+ @dataclass
582
+ class ConstantScheduler(Scheduler):
583
+ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
584
+ del step, max_steps
585
+ return initial_lr
586
+
587
+
588
+ PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
589
+
590
+
591
+ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
592
+ """
593
+ Separate parameters into weight decay and non weight decay groups.
594
+ """
595
+ param_groups: List[Dict[str, Any]]
596
+ param_group_defaults = {
597
+ "sharded": isinstance(model, FullyShardedDataParallel),
598
+ "max_grad_norm": cfg.max_grad_norm,
599
+ "max_grad_norm_ratio": cfg.max_grad_norm_ratio,
600
+ }
601
+
602
+ # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
603
+ decay = set()
604
+ no_decay = set()
605
+ all_params = {}
606
+ for mn, m in model.named_modules():
607
+ for pn, p in m.named_parameters():
608
+ # NOTE: because named_modules and named_parameters are recursive
609
+ # we will see the same tensors p many many times, but doing it this way
610
+ # allows us to know which parent module any tensor p belongs to...
611
+ if not p.requires_grad:
612
+ continue
613
+
614
+ fpn = f"{mn}.{pn}" if mn else pn
615
+ all_params[fpn] = p
616
+
617
+ if pn.endswith("bias"):
618
+ if cfg.optimizer.decay_norm_and_bias:
619
+ decay.add(fpn)
620
+ else:
621
+ no_decay.add(fpn)
622
+ elif pn.endswith("weight") and isinstance(m, nn.Linear):
623
+ decay.add(fpn)
624
+ elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)):
625
+ if cfg.optimizer.decay_norm_and_bias:
626
+ decay.add(fpn)
627
+ else:
628
+ no_decay.add(fpn)
629
+ elif pn.endswith("weight") and isinstance(m, nn.Embedding):
630
+ if cfg.optimizer.decay_embeddings:
631
+ decay.add(fpn)
632
+ else:
633
+ no_decay.add(fpn)
634
+
635
+ # Validate that we've considered every parameter
636
+ inter_params = decay & no_decay
637
+ union_params = decay | no_decay
638
+ assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
639
+ assert (
640
+ len(all_params.keys() - union_params) == 0
641
+ ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
642
+
643
+ # Create the pytorch optimizer groups.
644
+ decay_sorted = sorted(list(decay))
645
+ no_decay_sorted = sorted(list(no_decay))
646
+ param_groups = []
647
+ if len(decay_sorted) > 0:
648
+ param_groups.append(
649
+ {
650
+ "params": [all_params[pn] for pn in decay_sorted],
651
+ "param_names": decay_sorted,
652
+ **param_group_defaults,
653
+ }
654
+ )
655
+ if len(no_decay_sorted) > 0:
656
+ param_groups.append(
657
+ {
658
+ "params": [all_params[pn] for pn in no_decay_sorted],
659
+ "param_names": no_decay_sorted,
660
+ "weight_decay": 0.0,
661
+ **param_group_defaults,
662
+ }
663
+ )
664
+
665
+ # Validate fields.
666
+ for group in param_groups:
667
+ for key in PARAM_GROUP_FIELDS:
668
+ assert key in group
669
+
670
+ return param_groups
671
+
672
+
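As a quick sanity check of the split above, a toy model with the default decay flags (norms, biases, and embeddings not decayed) would be grouped like this (sketch, not part of this file):

    import torch.nn as nn

    toy = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4), nn.LayerNorm(4))
    # decay:    {"1.weight"}                                  (the nn.Linear weight)
    # no_decay: {"0.weight", "1.bias", "2.weight", "2.bias"}  (embedding, bias, layer norm)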
673
+ def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
674
+ """
675
+ Make sure old optim state dicts are compatible with new versions.
676
+ """
677
+ if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2:
678
+ assert optimizer.param_groups[1]["weight_decay"] == 0.0
679
+
680
+ # Decay
681
+ decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
682
+ decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"]
683
+
684
+ # No decay.
685
+ no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
686
+ no_decay_param_group["weight_decay"] = 0.0
687
+ no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"]
688
+
689
+ state_dict["param_groups"] = [decay_param_group, no_decay_param_group]
690
+
691
+ assert len(optimizer.param_groups) == len(state_dict["param_groups"])
692
+
693
+ # Make sure:
694
+ # - All required fields are included in the state dict,
695
+ # - And that the values of those fields don't change from what's currently set in the optimizer,
696
+ # since we might have changed those fields on purpose after a restart.
697
+ for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]):
698
+ for key in PARAM_GROUP_FIELDS:
699
+ sd_group[key] = group[key]
700
+
701
+ return state_dict
702
+
703
+
704
+ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
705
+ param_groups = get_param_groups(cfg, model)
706
+ log.info(f"Constructing optimizer with {len(param_groups)} param groups")
707
+ if cfg.optimizer.name == OptimizerType.lionw:
708
+ return LionW(
709
+ param_groups,
710
+ lr=cfg.optimizer.learning_rate,
711
+ betas=cfg.optimizer.betas,
712
+ weight_decay=cfg.optimizer.weight_decay,
713
+ )
714
+ elif cfg.optimizer.name == OptimizerType.adamw:
715
+ return AdamW(
716
+ param_groups,
717
+ lr=cfg.optimizer.learning_rate,
718
+ betas=cfg.optimizer.betas,
719
+ weight_decay=cfg.optimizer.weight_decay,
720
+ eps=1e-5,
721
+ )
722
+ else:
723
+ raise NotImplementedError
724
+
725
+
726
+ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler:
727
+ sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
728
+ if sched_cfg.name == SchedulerType.cosine_with_warmup:
729
+ return CosWithWarmup(
730
+ grad_clip_warmup_steps=None
731
+ if sched_cfg.grad_clip_warmup_steps is None
732
+ else int(sched_cfg.grad_clip_warmup_steps),
733
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
734
+ warmup_steps=int(sched_cfg.t_warmup),
735
+ alpha_f=sched_cfg.alpha_f,
736
+ t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
737
+ warmup_min_lr=sched_cfg.warmup_min_lr,
738
+ )
739
+ elif sched_cfg.name == SchedulerType.linear_with_warmup:
740
+ return LinearWithWarmup(
741
+ grad_clip_warmup_steps=None
742
+ if sched_cfg.grad_clip_warmup_steps is None
743
+ else int(sched_cfg.grad_clip_warmup_steps),
744
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
745
+ warmup_steps=int(sched_cfg.t_warmup),
746
+ alpha_f=sched_cfg.alpha_f,
747
+ t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
748
+ warmup_min_lr=sched_cfg.warmup_min_lr,
749
+ )
750
+ elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
751
+ return InvSqrtWithWarmup(
752
+ grad_clip_warmup_steps=None
753
+ if sched_cfg.grad_clip_warmup_steps is None
754
+ else int(sched_cfg.grad_clip_warmup_steps),
755
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
756
+ warmup_steps=int(sched_cfg.t_warmup),
757
+ warmup_min_lr=sched_cfg.warmup_min_lr,
758
+ )
759
+ elif sched_cfg.name == SchedulerType.max_scheduler:
760
+ return MaxScheduler(
761
+ grad_clip_warmup_steps=None
762
+ if sched_cfg.grad_clip_warmup_steps is None
763
+ else int(sched_cfg.grad_clip_warmup_steps),
764
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
765
+ sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
766
+ sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
767
+ warmup_min_lr=sched_cfg.warmup_min_lr,
768
+ )
769
+ elif sched_cfg.name == SchedulerType.constant:
770
+ return ConstantScheduler(
771
+ grad_clip_warmup_steps=None
772
+ if sched_cfg.grad_clip_warmup_steps is None
773
+ else int(sched_cfg.grad_clip_warmup_steps),
774
+ grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
775
+ warmup_min_lr=sched_cfg.warmup_min_lr,
776
+ )
777
+ else:
778
+ raise NotImplementedError
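Putting the pieces of this module together, the intended wiring looks roughly like the sketch below; `compute_loss` and `next_batch` are placeholders, and the `cfg` field names follow the config usage above rather than a guaranteed API:

    optim = build_optimizer(cfg, model)   # param groups with and without weight decay
    scheduler = build_scheduler(cfg)      # e.g. CosWithWarmup, per cfg.scheduler

    for step in range(1, max_steps + 1):
        loss = compute_loss(model, next_batch())   # placeholder forward/backward driver
        loss.backward()
        # Set this step's LR from the schedule before the optimizer step.
        for group in optim.param_groups:
            group["lr"] = scheduler.get_lr(cfg.optimizer.learning_rate, step, max_steps)
        optim.step()
        optim.zero_grad(set_to_none=True)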
model/py.typed ADDED
File without changes
model/safetensors_util.py ADDED
@@ -0,0 +1,81 @@
1
+ import base64
2
+ import pickle
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ import safetensors.torch
7
+ import torch
8
+
9
+ from olmo.aliases import PathOrStr
10
+
11
+ __all__ = [
12
+ "state_dict_to_safetensors_file",
13
+ "safetensors_file_to_state_dict",
14
+ ]
15
+
16
+
17
+ @dataclass(eq=True, frozen=True)
18
+ class STKey:
19
+ keys: Tuple
20
+ value_is_pickled: bool
21
+
22
+
23
+ def encode_key(key: STKey) -> str:
24
+ b = pickle.dumps((key.keys, key.value_is_pickled))
25
+ b = base64.urlsafe_b64encode(b)
26
+ return str(b, "ASCII")
27
+
28
+
29
+ def decode_key(key: str) -> STKey:
30
+ b = base64.urlsafe_b64decode(key)
31
+ keys, value_is_pickled = pickle.loads(b)
32
+ return STKey(keys, value_is_pickled)
33
+
34
+
35
+ def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
36
+ result = {}
37
+ for key, value in d.items():
38
+ if isinstance(value, torch.Tensor):
39
+ result[STKey((key,), False)] = value
40
+ elif isinstance(value, dict):
41
+ value = flatten_dict(value)
42
+ for inner_key, inner_value in value.items():
43
+ result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
44
+ else:
45
+ pickled = bytearray(pickle.dumps(value))
46
+ pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
47
+ result[STKey((key,), True)] = pickled_tensor
48
+ return result
49
+
50
+
51
+ def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
52
+ result: Dict = {}
53
+
54
+ for key, value in d.items():
55
+ if key.value_is_pickled:
56
+ value = pickle.loads(value.numpy().data)
57
+
58
+ target_dict = result
59
+ for k in key.keys[:-1]:
60
+ new_target_dict = target_dict.get(k)
61
+ if new_target_dict is None:
62
+ new_target_dict = {}
63
+ target_dict[k] = new_target_dict
64
+ target_dict = new_target_dict
65
+ target_dict[key.keys[-1]] = value
66
+
67
+ return result
68
+
69
+
70
+ def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
71
+ state_dict = flatten_dict(state_dict)
72
+ state_dict = {encode_key(k): v for k, v in state_dict.items()}
73
+ safetensors.torch.save_file(state_dict, filename)
74
+
75
+
76
+ def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
77
+ if map_location is None:
78
+ map_location = "cpu"
79
+ state_dict = safetensors.torch.load_file(filename, device=map_location)
80
+ state_dict = {decode_key(k): v for k, v in state_dict.items()}
81
+ return unflatten_dict(state_dict)
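A hedged round-trip example for the two public helpers above; non-tensor leaves are pickled into byte tensors, so arbitrarily nested state dicts survive the trip:

    import torch

    state = {"model": {"w": torch.randn(2, 2)}, "step": 7, "note": "hello"}
    state_dict_to_safetensors_file(state, "state.safetensors")
    restored = safetensors_file_to_state_dict("state.safetensors")
    assert torch.equal(restored["model"]["w"], state["model"]["w"])
    assert restored["step"] == 7 and restored["note"] == "hello"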
model/tokenization_olmo_fast.py ADDED
@@ -0,0 +1,16 @@
1
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast
2
+
3
+ from .configuration_olmo import OLMoConfig
4
+
5
+
6
+ class OLMoTokenizerFast(PreTrainedTokenizerFast):
7
+ # Note: OLMo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
8
+ pass
9
+
10
+ # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
11
+ # # This is required to make the implementation complete.
12
+ # pass
13
+
14
+
15
+ # Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
16
+ AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
model/tokenizer.py ADDED
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import List, Optional, Union
6
+
7
+ from tokenizers import Tokenizer as BaseTokenizer
8
+
9
+ from .aliases import PathOrStr
10
+ from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
11
+ from .exceptions import OLMoConfigurationError
12
+
13
+ __all__ = ["Tokenizer"]
14
+
15
+
16
+ class Tokenizer:
17
+ """
18
+ A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.
19
+
20
+ :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
21
+ :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
22
+ :param truncate_to: Truncate when tokenizing to this number of token IDs.
23
+ :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
24
+ on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
25
+ this setting has no effect.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ base_tokenizer: BaseTokenizer,
31
+ eos_token_id: int,
32
+ pad_token_id: Optional[int] = None,
33
+ truncate_to: Optional[int] = None,
34
+ truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
35
+ ):
36
+ self.base_tokenizer = base_tokenizer
37
+ self.base_tokenizer.no_truncation()
38
+ self.eos_token_id = eos_token_id
39
+ self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
40
+ self.truncate_to = truncate_to
41
+ self.truncate_direction = TruncationDirection(truncate_direction)
42
+
43
+ @property
44
+ def vocab_size(self) -> int:
45
+ return self.base_tokenizer.get_vocab_size()
46
+
47
+ @property
48
+ def eos_token(self) -> str:
49
+ return self.decode([self.eos_token_id], skip_special_tokens=False)
50
+
51
+ @property
52
+ def pad_token(self) -> str:
53
+ return self.decode([self.pad_token_id], skip_special_tokens=False)
54
+
55
+ @classmethod
56
+ def from_train_config(cls, config: TrainConfig) -> Tokenizer:
57
+ tokenizer_identifier = config.tokenizer.identifier
58
+ if Path(tokenizer_identifier).is_file():
59
+ tokenizer = cls.from_file(
60
+ tokenizer_identifier,
61
+ eos_token_id=config.model.eos_token_id,
62
+ pad_token_id=config.model.pad_token_id,
63
+ )
64
+ else:
65
+ tokenizer = cls.from_pretrained(
66
+ tokenizer_identifier,
67
+ eos_token_id=config.model.eos_token_id,
68
+ pad_token_id=config.model.pad_token_id,
69
+ )
70
+ if config.model.vocab_size != tokenizer.vocab_size:
71
+ raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
72
+ return tokenizer
73
+
74
+ @classmethod
75
+ def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
76
+ """
77
+ Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.
78
+
79
+ :param identifier: The identifier of a model on the Hub that contains a
80
+ ``tokenizer.json`` file.
81
+ :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
82
+ """
83
+ base_tokenizer = BaseTokenizer.from_pretrained(identifier)
84
+ eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
85
+ return cls(base_tokenizer, eos_token_id, **kwargs)
86
+
87
+ @classmethod
88
+ def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
89
+ """
90
+ Initialize a tokenizer from a file.
91
+
92
+ You can create those files with ``BaseTokenizer.save()``.
93
+
94
+ :param filename: The name of a file containing a tokenizer specification.
95
+ :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
96
+ """
97
+ base_tokenizer = BaseTokenizer.from_file(filename)
98
+ eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
99
+ return cls(base_tokenizer, eos_token_id, **kwargs)
100
+
101
+ @classmethod
102
+ def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
103
+ """
104
+ Load a tokenizer from a checkpoint.
105
+ """
106
+ from cached_path import cached_path
107
+
108
+ # Load configs.
109
+ config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
110
+ tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
111
+ model_config = ModelConfig.load(config_path, key="model")
112
+
113
+ # Initialize tokenizer and validate vocab size.
114
+ if Path(tokenizer_config.identifier).is_file():
115
+ tokenizer = cls.from_file(
116
+ tokenizer_config.identifier,
117
+ eos_token_id=model_config.eos_token_id,
118
+ pad_token_id=model_config.pad_token_id,
119
+ )
120
+ else:
121
+ tokenizer = cls.from_pretrained(
122
+ tokenizer_config.identifier,
123
+ eos_token_id=model_config.eos_token_id,
124
+ pad_token_id=model_config.pad_token_id,
125
+ )
126
+ if model_config.vocab_size != tokenizer.vocab_size:
127
+ raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
128
+ return tokenizer
129
+
130
+ def add_special_tokens(self, input_ids: List[int]) -> List[int]:
131
+ """
132
+ Add special tokens in-place (if not already present) to the given token IDs.
133
+ """
134
+ if not input_ids or input_ids[-1] != self.eos_token_id:
135
+ input_ids.append(self.eos_token_id)
136
+ return input_ids
137
+
138
+ def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
139
+ return 2 if is_pair else 1
140
+
141
+ def _truncate(
142
+ self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
143
+ ) -> list[int]:
144
+ if truncate_to is None or len(input_ids) <= truncate_to:
145
+ return input_ids
146
+ elif direction == TruncationDirection.left:
147
+ return input_ids[len(input_ids) - truncate_to :]
148
+ else:
149
+ return input_ids[: -(len(input_ids) - truncate_to)]
150
+
151
+ def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
152
+ """
153
+ Encode a string into token IDs.
154
+ """
155
+ return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]
156
+
157
+ def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
158
+ """
159
+ Encode a batch of strings into token IDs.
160
+ """
161
+ truncate_to = self.truncate_to
162
+ if truncate_to is not None and add_special_tokens:
163
+ truncate_to -= self.num_special_tokens_to_add(False)
164
+
165
+ batch_encoding = self.base_tokenizer.encode_batch(inputs)
166
+
167
+ all_input_ids = []
168
+ for encoding in batch_encoding:
169
+ input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
170
+ if add_special_tokens:
171
+ input_ids = self.add_special_tokens(input_ids)
172
+ all_input_ids.append(input_ids)
173
+
174
+ return all_input_ids
175
+
176
+ def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
177
+ """
178
+ Decode a list of token IDs to a string.
179
+ """
180
+ return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
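A hedged usage sketch (requires network access to the Hub; "gpt2" is just a stand-in for any repo that ships a `tokenizer.json`):

    tok = Tokenizer.from_pretrained("gpt2", eos_token_id=50256, truncate_to=8)
    ids = tok.encode("a long input string that will be truncated")
    assert len(ids) <= 8 and ids[-1] == tok.eos_token_id  # EOS appended after truncation
    text = tok.decode(ids)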
model/torch_util.py ADDED
@@ -0,0 +1,139 @@
1
+ import gc
2
+ import os
3
+ from typing import Optional, TypeVar
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ def seed_all(seed: int):
12
+ """Seed all rng objects."""
13
+ import random
14
+
15
+ import numpy as np
16
+
17
+ if seed < 0 or seed > 2**32 - 1:
18
+ raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
19
+ random.seed(seed)
20
+ np.random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ # torch.manual_seed may call manual_seed_all but calling it again here
23
+ # to make sure it gets called at least once
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ def is_distributed() -> bool:
28
+ return dist.is_available() and dist.is_initialized()
29
+
30
+
31
+ def get_node_rank() -> int:
32
+ return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
33
+
34
+
35
+ def get_world_size() -> int:
36
+ if is_distributed():
37
+ return dist.get_world_size()
38
+ else:
39
+ return 1
40
+
41
+
42
+ def get_local_world_size() -> int:
43
+ return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
44
+
45
+
46
+ def get_global_rank() -> int:
47
+ return int(os.environ.get("RANK") or dist.get_rank())
48
+
49
+
50
+ def get_local_rank() -> int:
51
+ return int(os.environ.get("LOCAL_RANK") or 0)
52
+
53
+
54
+ def get_fs_local_rank() -> int:
55
+ """Get the local rank per filesystem, meaning that, regardless of the number of nodes,
56
+ if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
57
+ but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
58
+ """
59
+ return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
60
+
61
+
62
+ def move_to_device(o: T, device: torch.device) -> T:
63
+ if isinstance(o, torch.Tensor):
64
+ return o.to(device) # type: ignore[return-value]
65
+ elif isinstance(o, dict):
66
+ return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
67
+ elif isinstance(o, list):
68
+ return [move_to_device(x, device) for x in o] # type: ignore[return-value]
69
+ elif isinstance(o, tuple):
70
+ return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
71
+ else:
72
+ return o
73
+
74
+
75
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
76
+ """
77
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
78
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
79
+ """
80
+ if check_neg_inf:
81
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
82
+ if check_pos_inf:
83
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
84
+
85
+
86
+ def get_default_device() -> torch.device:
87
+ if torch.cuda.is_available() and torch.cuda.is_initialized():
88
+ return torch.device("cuda")
89
+ else:
90
+ return torch.device("cpu")
91
+
92
+
93
+ def barrier() -> None:
94
+ if is_distributed():
95
+ dist.barrier()
96
+
97
+
98
+ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
99
+ """
100
+ Get the peak GPU memory usage in MB across all ranks.
101
+ Only rank 0 will get the final result.
102
+ """
103
+ if not torch.cuda.is_available():
104
+ return None
105
+
106
+ device = torch.device("cuda")
107
+ peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
108
+ if is_distributed():
109
+ peak_mb_tensor = torch.tensor(peak_mb, device=device)
110
+ dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
111
+ peak_mb = peak_mb_tensor.item()
112
+
113
+ if reset:
114
+ # Reset peak stats.
115
+ torch.cuda.reset_max_memory_allocated(device)
116
+
117
+ return peak_mb
118
+
119
+
120
+ V = TypeVar("V", bool, int, float)
121
+
122
+
123
+ def synchronize_value(value: V, device: torch.device) -> V:
124
+ if dist.is_available() and dist.is_initialized():
125
+ value_tensor = torch.tensor(value, device=device)
126
+ dist.broadcast(value_tensor, 0)
127
+ return value_tensor.item() # type: ignore
128
+ else:
129
+ return value
130
+
131
+
132
+ def synchronize_flag(flag: bool, device: torch.device) -> bool:
133
+ return synchronize_value(flag, device)
134
+
135
+
136
+ def gc_cuda():
137
+ gc.collect()
138
+ if torch.cuda.is_available():
139
+ torch.cuda.empty_cache()
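Two of the helpers above in isolation (single-process, CPU-friendly sketch):

    import torch

    batch = {"input_ids": torch.ones(2, 4, dtype=torch.long), "extras": [torch.zeros(1)]}
    batch = move_to_device(batch, get_default_device())  # recurses into dicts/lists/tuples

    x = torch.tensor([0.0, float("-inf")])
    ensure_finite_(x)                                     # -inf -> dtype minimum, in place
    assert torch.isfinite(x).all()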
model/train.py ADDED
@@ -0,0 +1,1231 @@
1
+ from __future__ import annotations
2
+
3
+ import cProfile
4
+ import gc
5
+ import logging
6
+ import math
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from collections import deque
12
+ from dataclasses import dataclass, field
13
+ from itertools import islice
14
+ from pathlib import Path
15
+ from pstats import SortKey
16
+ from typing import Any, Callable, Deque, Dict, List, Optional, TextIO, Tuple
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed as dist
21
+ import torch.nn.functional as F
22
+ import wandb
23
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
24
+ from torch.utils.data import DataLoader
25
+
26
+ from .aliases import PathOrStr
27
+ from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
28
+ from .config import (
29
+ CheckpointType,
30
+ SchedulerUnits,
31
+ ShardedCheckpointerType,
32
+ SpeedMonitorConfig,
33
+ TrainConfig,
34
+ )
35
+ from .data import IterableDataset
36
+ from .eval import Evaluator
37
+ from .exceptions import OLMoConfigurationError
38
+ from .model import OLMo
39
+ from .optim import Optimizer, Scheduler
40
+ from .torch_util import (
41
+ barrier,
42
+ gc_cuda,
43
+ get_fs_local_rank,
44
+ get_global_rank,
45
+ get_world_size,
46
+ move_to_device,
47
+ peak_gpu_memory,
48
+ synchronize_flag,
49
+ synchronize_value,
50
+ )
51
+ from .util import upload
52
+
53
+ __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
54
+
55
+ log = logging.getLogger(__name__)
56
+
57
+
58
+ @dataclass
59
+ class SpeedMonitor:
60
+ cfg: SpeedMonitorConfig
61
+ start_times: Deque[float] = field(default_factory=lambda: deque([]))
62
+ global_total_tokens: int = 0
63
+ device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))
64
+
65
+ def batch_start(self, global_total_tokens: int, device_batch_num_tokens: int, record: bool = True) -> None:
66
+ self.global_total_tokens = global_total_tokens
67
+ if record:
68
+ if len(self.start_times) >= self.cfg.window_size:
69
+ self.start_times.popleft()
70
+ self.device_interval_tokens.popleft()
71
+ self.start_times.append(time.monotonic())
72
+ self.device_interval_tokens.append(device_batch_num_tokens)
73
+
74
+ def reset(self) -> None:
75
+ self.start_times.clear()
76
+ self.device_interval_tokens.clear()
77
+
78
+ def check(self) -> Dict[str, float]:
79
+ metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
80
+ if self.start_times:
81
+ interval_seconds = time.monotonic() - self.start_times[0]
82
+ interval_batches = len(self.start_times)
83
+ interval_tokens = sum(self.device_interval_tokens)
84
+ metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
85
+ metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
86
+ return metrics
87
+
88
+
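A hedged sketch of how the monitor is driven (it assumes `SpeedMonitorConfig` exposes a `window_size` field, which is how `batch_start` uses it above):

    import time

    monitor = SpeedMonitor(cfg=SpeedMonitorConfig(window_size=20))  # field name assumed
    total_tokens = 0
    for _ in range(5):
        total_tokens += 4096
        monitor.batch_start(global_total_tokens=total_tokens, device_batch_num_tokens=4096)
        time.sleep(0.01)      # stand-in for the real forward/backward work
    print(monitor.check())    # throughput/total_tokens, tokens_per_second, batches_per_second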
89
+ @dataclass
90
+ class LRMonitor:
91
+ optim: torch.optim.Optimizer
92
+
93
+ def check(self) -> Dict[str, float]:
94
+ lrs = [group["lr"] for group in self.optim.param_groups]
95
+ return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}
96
+
97
+
98
+ def cross_entropy_loss(
99
+ logits, labels, ignore_index: int = -100, reduction: str = "mean", compute_z_loss: bool = False
100
+ ):
101
+ loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
102
+
103
+ if not compute_z_loss:
104
+ return loss, None
105
+
106
+ z_squared = logits.logsumexp(-1).pow(2)
107
+ if reduction == "mean":
108
+ z_squared = (z_squared * (labels != ignore_index)).mean()
109
+ elif reduction == "sum":
110
+ z_squared = (z_squared * (labels != ignore_index)).sum()
111
+
112
+ z_loss = 1e-4 * z_squared
113
+
114
+ return loss, z_loss
115
+
116
+
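The auxiliary z-loss penalizes large log-partition values (the log-sum-exp over the vocabulary); a tiny numeric sketch of the function above:

    import torch

    logits = torch.randn(6, 32)              # (batch * seq_len, vocab_size)
    labels = torch.randint(0, 32, (6,))
    ce, z = cross_entropy_loss(logits, labels, compute_z_loss=True)
    # z == 1e-4 * mean over non-ignored positions of logsumexp(logits, dim=-1) ** 2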
117
+ @dataclass
118
+ class Trainer:
119
+ cfg: TrainConfig
120
+ model: OLMo
121
+ fsdp_model: FSDP
122
+ optim: Optimizer
123
+ scheduler: Scheduler
124
+ train_loader: DataLoader
125
+ device: torch.device
126
+ evaluators: List[Evaluator]
127
+ epoch: Optional[int] = None
128
+ global_step: int = 0
129
+ global_train_examples_seen_this_epoch: int = 0
130
+ """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
131
+ the data loader position on restarts."""
132
+ global_train_tokens_seen: int = 0
133
+ """Tracks the global total number of tokens trained on."""
134
+ checkpoints: List[Path] = field(default_factory=list)
135
+ unsharded_checkpoints: List[Path] = field(default_factory=list)
136
+ ephemeral_checkpoints: List[Path] = field(default_factory=list)
137
+ min_train_loss: float = float("inf")
138
+ cur_train_loss: float = float("inf")
139
+ indices_file: Optional[TextIO] = None
140
+ _start_time: float = 0.0
141
+ _gc_init_state: bool = True
142
+ loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
143
+ last_sharded_checkpoint_step: Optional[int] = None
144
+ last_unsharded_checkpoint_step: Optional[int] = None
145
+
146
+ def __post_init__(self):
147
+ if self.cfg.fused_loss:
148
+ from flash_attn.ops.triton.cross_entropy import ( # type: ignore
149
+ cross_entropy_loss,
150
+ )
151
+
152
+ def fused_loss_fn(
153
+ logits, labels, ignore_index: int = -100, reduction: str = "mean", compute_z_loss: bool = False
154
+ ):
155
+ loss, z_loss = cross_entropy_loss(
156
+ logits,
157
+ labels,
158
+ label_smoothing=0.0,
159
+ logit_scale=1.0,
160
+ lse_square_scale=0.0,
161
+ ignored_index=ignore_index,
162
+ inplace_backward=False,
163
+ process_group=None,
164
+ )
165
+
166
+ mask = labels != ignore_index
167
+
168
+ if reduction == "mean":
169
+ loss = loss.sum() / mask.sum()
170
+ elif reduction == "sum":
171
+ loss = loss.sum()
172
+ else:
173
+ loss = loss
174
+
175
+ if not compute_z_loss:
176
+ return loss, None
177
+
178
+ if reduction == "mean":
179
+ z_loss = z_loss.sum() / mask.sum()
180
+ elif reduction == "sum":
181
+ z_loss = z_loss.sum()
182
+ else:
183
+ z_loss = z_loss
184
+
185
+ return loss, z_loss
186
+
187
+ self.loss_fn = fused_loss_fn
188
+
189
+ @property
190
+ def dataset(self) -> IterableDataset:
191
+ assert isinstance(self.train_loader.dataset, IterableDataset)
192
+ return self.train_loader.dataset
193
+
194
+ @property
195
+ def tokens_per_batch(self) -> int:
196
+ return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
197
+
198
+ @property
199
+ def batches_per_epoch(self) -> int:
200
+ return self.dataset.total_size // self.cfg.global_train_batch_size
201
+
202
+ @property
203
+ def max_epochs(self) -> int:
204
+ if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
205
+ return int(self.cfg.max_duration[:-2].strip())
206
+ else:
207
+ return 1
208
+
209
+ @property
210
+ def max_steps(self) -> int:
211
+ if isinstance(self.cfg.max_duration, int):
212
+ return self.cfg.max_duration
213
+ elif isinstance(self.cfg.max_duration, str):
214
+ if self.cfg.max_duration.endswith("T"):
215
+ # convert to float *first* to handle scientific notation
216
+ max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
217
+ tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
218
+ steps_remaining = tokens_remaining // self.tokens_per_batch
219
+ return self.global_step + steps_remaining
220
+ elif self.cfg.max_duration.endswith("ep"):
221
+ max_epochs = int(self.cfg.max_duration[:-2].strip())
222
+ return max_epochs * self.batches_per_epoch
223
+ else:
224
+ # convert to float *first* to handle scientific notation
225
+ return int(float(self.cfg.max_duration))
226
+ else:
227
+ raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
228
+
229
+ @property
230
+ def max_tokens(self) -> int:
231
+ if isinstance(self.cfg.max_duration, int):
232
+ return (
233
+ self.global_train_tokens_seen
234
+ + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
235
+ )
236
+ elif isinstance(self.cfg.max_duration, str):
237
+ if self.cfg.max_duration.endswith("T"):
238
+ # convert to float *first* to handle scientific notation
239
+ return int(float(self.cfg.max_duration[:-1].strip()))
240
+ elif self.cfg.max_duration.endswith("ep"):
241
+ max_epochs = int(self.cfg.max_duration[:-2].strip())
242
+ return max_epochs * self.batches_per_epoch * self.tokens_per_batch
243
+ else:
244
+ # convert to float *first* to handle scientific notation
245
+ return (
246
+ self.global_train_tokens_seen
247
+ + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
248
+ )
249
+ else:
250
+ raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
251
+
252
+ @property
253
+ def scheduler_current(self) -> int:
254
+ if self.cfg.scheduler.units == SchedulerUnits.steps:
255
+ return self.global_step
256
+ elif self.cfg.scheduler.units == SchedulerUnits.tokens:
257
+ return self.global_train_tokens_seen
258
+ else:
259
+ raise NotImplementedError(self.cfg.scheduler.units)
260
+
261
+ @property
262
+ def scheduler_max(self) -> int:
263
+ if self.cfg.scheduler.units == SchedulerUnits.steps:
264
+ return self.max_steps
265
+ elif self.cfg.scheduler.units == SchedulerUnits.tokens:
266
+ return self.max_tokens
267
+ else:
268
+ raise NotImplementedError(self.cfg.scheduler.units)
269
+
270
+ def trainer_state_dict(self) -> Dict[str, Any]:
271
+ return {
272
+ "epoch": self.epoch,
273
+ "global_step": self.global_step,
274
+ "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
275
+ "global_train_tokens_seen": self.global_train_tokens_seen,
276
+ "world_size": get_world_size(),
277
+ "checkpoints": self.checkpoints,
278
+ "unsharded_checkpoints": self.unsharded_checkpoints,
279
+ "ephemeral_checkpoints": self.ephemeral_checkpoints,
280
+ "rng": {
281
+ "python": random.getstate(),
282
+ "numpy": np.random.get_state(),
283
+ "torch": torch.random.get_rng_state(),
284
+ "cuda": torch.cuda.get_rng_state(),
285
+ },
286
+ }
287
+
288
+ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
289
+ # Checkpoint paths.
290
+ self.checkpoints = [
291
+ path
292
+ for path in state_dict["checkpoints"]
293
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
294
+ ]
295
+ self.unsharded_checkpoints = [
296
+ path
297
+ for path in state_dict["unsharded_checkpoints"]
298
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
299
+ ]
300
+ self.ephemeral_checkpoints = [
301
+ path
302
+ for path in state_dict.get("ephemeral_checkpoints", [])
303
+ if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
304
+ ]
305
+
306
+ # Dataset / dataloader position.
307
+ checkpoint_epoch = state_dict.get("epoch", 0)
308
+ self.global_step = state_dict["global_step"]
309
+ self.global_train_examples_seen_this_epoch = state_dict.get(
310
+ "global_train_examples_seen_this_epoch",
311
+ state_dict.get( # for backwards compatibility
312
+ "global_train_examples_seen",
313
+ state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
314
+ ),
315
+ )
316
+ self.global_train_tokens_seen = state_dict.get(
317
+ "global_train_tokens_seen",
318
+ state_dict.get("global_data_step", self.global_step) # for backwards compatibility
319
+ * self.cfg.global_train_batch_size
320
+ * self.cfg.model.max_sequence_length,
321
+ )
322
+
323
+ if not self.cfg.restore_dataloader:
324
+ self.epoch = 0
325
+ self.global_train_tokens_seen = 0
326
+ self.global_train_examples_seen_this_epoch = 0
327
+ elif self.epoch is None:
328
+ self.epoch = checkpoint_epoch
329
+ elif checkpoint_epoch != self.epoch:
330
+ log.info(f"Starting new epoch (epoch = {self.epoch})")
331
+ self.global_train_examples_seen_this_epoch = 0
332
+
333
+ if self.cfg.fast_forward_batches:
334
+ log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
335
+ # Technically we don't "see" these batches that we fast-forward through, but we use
336
+ # this variable to update the position of the dataset so we need to include them here.
337
+ self.global_train_examples_seen_this_epoch += (
338
+ self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
339
+ )
340
+ # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
341
+ # that variable is meant to track the actual number of tokens trained on.
342
+
343
+ if self.global_train_examples_seen_this_epoch > 0:
344
+ assert isinstance(self.dataset, IterableDataset)
345
+ log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
346
+ self.dataset.start_index = self.global_train_examples_seen_this_epoch
347
+
348
+ # Reset learning rate and weight decay to the values from the config, not the checkpoint.
349
+ log.info("Resetting learning rate...")
350
+ new_learning_rate = self.scheduler.get_lr(
351
+ self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
352
+ )
353
+ for group in self.optim.param_groups:
354
+ group["lr"] = new_learning_rate
355
+ group["initial_lr"] = self.cfg.optimizer.learning_rate
356
+ if "weight_decay" in group and group["weight_decay"] > 0.0:
357
+ group["weight_decay"] = self.cfg.optimizer.weight_decay
358
+
359
+ # RNG states.
360
+ if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
361
+ log.info("Restoring RNG states...")
362
+ rng_state = state_dict["rng"]
363
+ self.restore_rng_state(rng_state)
364
+ else:
365
+ log.warning(
366
+ "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
367
+ "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
368
+ "with a different world size. If that's the case you can safely ignore this warning."
369
+ )
370
+
371
+ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
372
+ random.setstate(rng_state["python"])
373
+ np.random.set_state(rng_state["numpy"])
374
+ torch.set_rng_state(rng_state["torch"])
375
+ torch.cuda.set_rng_state(rng_state["cuda"])
376
+
377
+ def _save_checkpoint(
378
+ self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
379
+ ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
380
+ if checkpoint_type == CheckpointType.sharded:
381
+ suffix = ""
382
+ current_checkpoints = self.checkpoints
383
+ link_latest = get_fs_local_rank() == 0
384
+ num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
385
+ elif checkpoint_type == CheckpointType.unsharded:
386
+ suffix = "-unsharded"
387
+ current_checkpoints = self.unsharded_checkpoints
388
+ link_latest = get_global_rank() == 0
389
+ num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
390
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
391
+ suffix = ""
392
+ current_checkpoints = self.ephemeral_checkpoints
393
+ link_latest = get_fs_local_rank() == 0
394
+ num_checkpoints_to_keep = 1
395
+ else:
396
+ raise NotImplementedError(checkpoint_type)
397
+
398
+ # Zero-gradients to avoid gathering them.
399
+ self.optim.zero_grad(set_to_none=True)
400
+
401
+ # Flush data indices file.
402
+ # TODO: upload the indices files?
403
+ if self.indices_file is not None:
404
+ self.indices_file.flush()
405
+
406
+ checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
407
+ remote_checkpoint_dir: Optional[str] = None
408
+ if self.cfg.remote_save_folder is not None:
409
+ remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
410
+ current_checkpoints.append(checkpoint_dir)
411
+
412
+ # Save the checkpoint.
413
+ try:
414
+ checkpointer.save_checkpoint(
415
+ checkpoint_dir,
416
+ self.fsdp_model,
417
+ self.optim,
418
+ self.trainer_state_dict(),
419
+ upload_to=remote_checkpoint_dir,
420
+ )
421
+ except FileExistsError:
422
+ raise OLMoConfigurationError(
423
+ f"Checkpoint for step {self.global_step} already exists, use --save-overwrite to overwrite it"
424
+ )
425
+
426
+ if link_latest:
427
+ if get_global_rank() == 0:
428
+ # Link to 'latest'.
429
+ latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
430
+ latest_path.unlink(missing_ok=True)
431
+ try:
432
+ latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
433
+ except FileExistsError:
434
+ # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
435
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
436
+ # file-systems.
437
+ if latest_path.resolve().name != checkpoint_dir.name:
438
+ raise
439
+
440
+ # Remove old checkpoints.
441
+ if num_checkpoints_to_keep > 0:
442
+ while len(current_checkpoints) > num_checkpoints_to_keep:
443
+ self.remove_checkpoint(0, checkpoint_type)
444
+
445
+ barrier()
446
+
447
+ if remote_checkpoint_dir is not None:
448
+ return remote_checkpoint_dir, checkpoint_dir
449
+ else:
450
+ return checkpoint_dir, None
451
+
452
+ def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
453
+ checkpointer = build_sharded_checkpointer(self.cfg)
454
+ result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
455
+ self.last_sharded_checkpoint_step = self.global_step
456
+ return result
457
+
458
+ def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
459
+ checkpointer = build_sharded_checkpointer(self.cfg)
460
+ result = self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)
461
+ self.last_sharded_checkpoint_step = self.global_step
462
+ return result
463
+
464
+ def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
465
+ oldest_checkpoint = checkpoints.pop(idx)
466
+ barrier()
467
+ if get_global_rank() == 0 and oldest_checkpoint.is_dir():
468
+ shutil.rmtree(oldest_checkpoint, ignore_errors=True)
469
+ latest_path = Path(self.cfg.save_folder) / "latest"
470
+ if latest_path.resolve() == oldest_checkpoint.resolve():
471
+ latest_path.unlink()
472
+ barrier()
473
+
474
+ def remove_sharded_checkpoint(self, idx: int = 0):
475
+ self._remove_sharded_checkpoint(idx, self.checkpoints)
476
+
477
+ def remove_ephemeral_checkpoint(self, idx: int = 0):
478
+ self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)
479
+
480
+ def restore_sharded_checkpoint(
481
+ self,
482
+ load_path: PathOrStr,
483
+ local_cache: Optional[PathOrStr] = None,
484
+ *,
485
+ load_optimizer_state: bool = True,
486
+ load_trainer_state: bool = True,
487
+ sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
488
+ ):
489
+ # Zero-gradients to avoid gathering them.
490
+ self.optim.zero_grad(set_to_none=True)
491
+ checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
492
+ trainer_state = checkpointer.restore_checkpoint(
493
+ load_path,
494
+ self.fsdp_model,
495
+ self.optim,
496
+ local_cache=local_cache,
497
+ load_optimizer_state=load_optimizer_state,
498
+ )
499
+ if load_trainer_state:
500
+ self.load_trainer_state_dict(trainer_state)
501
+ barrier()
502
+
503
+ def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
504
+ checkpointer = FullCheckpointer(self.cfg)
505
+ result = self._save_checkpoint(checkpointer, CheckpointType.unsharded)
506
+ self.last_unsharded_checkpoint_step = self.global_step
507
+ return result
508
+
509
+ def remove_unsharded_checkpoint(self, idx: int = 0):
510
+ barrier()
511
+ oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
512
+ if get_global_rank() == 0 and oldest_checkpoint.is_dir():
513
+ shutil.rmtree(oldest_checkpoint, ignore_errors=True)
514
+ latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
515
+ if latest_path.resolve() == oldest_checkpoint.resolve():
516
+ latest_path.unlink()
517
+ barrier()
518
+
519
+ def restore_unsharded_checkpoint(
520
+ self,
521
+ load_path: PathOrStr,
522
+ local_cache: Optional[PathOrStr] = None,
523
+ *,
524
+ load_optimizer_state: bool = True,
525
+ load_trainer_state: bool = True,
526
+ ):
527
+ # Zero-gradients to avoid gathering them.
528
+ self.optim.zero_grad(set_to_none=True)
529
+ checkpointer = FullCheckpointer(self.cfg)
530
+ trainer_state = checkpointer.restore_checkpoint(
531
+ load_path,
532
+ self.fsdp_model,
533
+ self.optim,
534
+ local_cache=local_cache,
535
+ load_optimizer_state=load_optimizer_state,
536
+ )
537
+ if load_trainer_state:
538
+ self.load_trainer_state_dict(trainer_state)
539
+ barrier()
540
+
541
+ def save_checkpoint(
542
+ self, checkpoint_type: CheckpointType = CheckpointType.sharded
543
+ ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
544
+ result: Tuple[PathOrStr, Optional[PathOrStr]]
545
+ if checkpoint_type == CheckpointType.sharded:
546
+ result = self.save_sharded_checkpoint()
547
+ elif checkpoint_type == CheckpointType.unsharded:
548
+ result = self.save_unsharded_checkpoint()
549
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
550
+ result = self.save_ephemeral_checkpoint()
551
+ else:
552
+ raise NotImplementedError(checkpoint_type)
553
+
554
+ gc_cuda()
555
+ return result
556
+
557
+ def restore_checkpoint(
558
+ self,
559
+ load_path: PathOrStr,
560
+ *,
561
+ checkpoint_type: Optional[CheckpointType] = None,
562
+ local_cache: Optional[PathOrStr] = None,
563
+ load_optimizer_state: bool = True,
564
+ load_trainer_state: bool = True,
565
+ sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
566
+ ):
567
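+ # When no checkpoint type is given, infer it from the path: a '-unsharded' suffix means an unsharded checkpoint.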
+ if checkpoint_type == CheckpointType.unsharded or (
568
+ checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
569
+ ):
570
+ self.restore_unsharded_checkpoint(
571
+ load_path,
572
+ local_cache=local_cache,
573
+ load_optimizer_state=load_optimizer_state,
574
+ load_trainer_state=load_trainer_state,
575
+ )
576
+ elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
577
+ self.restore_sharded_checkpoint(
578
+ load_path,
579
+ local_cache=local_cache,
580
+ load_optimizer_state=load_optimizer_state,
581
+ load_trainer_state=load_trainer_state,
582
+ sharded_checkpointer=sharded_checkpointer,
583
+ )
584
+ elif checkpoint_type is not None:
585
+ raise NotImplementedError(checkpoint_type)
586
+
587
+ gc_cuda()
588
+
589
+ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
590
+ if checkpoint_type == CheckpointType.sharded:
591
+ self.remove_sharded_checkpoint(idx=idx)
592
+ elif checkpoint_type == CheckpointType.unsharded:
593
+ self.remove_unsharded_checkpoint(idx=idx)
594
+ elif checkpoint_type == CheckpointType.sharded_ephemeral:
595
+ self.remove_ephemeral_checkpoint(idx=idx)
596
+ else:
597
+ raise NotImplementedError(checkpoint_type)
598
+
599
+ def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
600
+ # Labels are just input IDs shifted to the left (first item is ignored).
601
+ labels, label_mask, attention_mask = (
602
+ batch["input_ids"].clone(),
603
+ batch.get("label_mask"),
604
+ batch.get("attention_mask"),
605
+ )
606
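+ # Masked positions are set to -100, which is passed as ignore_index to the loss below so they don't contribute.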
+ if label_mask is not None:
607
+ labels.masked_fill_(~label_mask, -100)
608
+ if attention_mask is not None:
609
+ labels.masked_fill_(attention_mask == 0.0, -100)
610
+ return labels[..., 1:].contiguous()
611
+
612
+ def model_forward(
613
+ self, batch: Dict[str, Any], loss_reduction: str = "mean", compute_z_loss: bool = False
614
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
615
+ # shape: (batch_size, seq_len, vocab_size)
616
+ logits = self.fsdp_model(
617
+ input_ids=batch["input_ids"],
618
+ attention_mask=batch.get("attention_mask"),
619
+ attention_bias=batch.get("attention_bias"),
620
+ ).logits
621
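+ # Drop the final position so the logit at position t is scored against the token at position t + 1 (next-token prediction).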
+ logits_for_loss = logits[..., :-1, :].contiguous()
622
+ # shape: (batch_size * seq_len, vocab_size)
623
+ logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
624
+ # shape: (batch_size, seq_len)
625
+ labels = self.get_labels(batch)
626
+ # shape: (batch_size * seq_len,)
627
+ labels = labels.view(-1)
628
+ ce_loss, z_loss = self.loss_fn(
629
+ logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction, compute_z_loss=compute_z_loss
630
+ )
631
+ if loss_reduction == "none":
632
+ # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
633
+ ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
634
+ if z_loss is not None:
635
+ z_loss = z_loss.view(batch["input_ids"].shape[0], -1)
636
+ return ce_loss, z_loss, logits
637
+
638
+ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
639
+ # Split into micro-batches.
640
+ # print(f"Start preparing micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
641
+ micro_batches = self.split_batch(batch)
642
+
643
+ # In case this helps with memory utilization.
644
+ del batch
645
+
646
+ ce_batch_loss = torch.tensor(0.0, device=self.device)
647
+ z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
648
+ # print(f"Start training micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
649
+ for micro_batch in micro_batches:
650
+ with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
651
+ # Run forward pass.
652
+ # print(f"Start forward pass at step {self.global_step}") if get_global_rank() == 0 else None
653
+ # print(f"micro_batch['input_ids'].shape: {micro_batch['input_ids'].shape}") if get_global_rank() == 0 else None
654
+ # print(f"min micro_batch['input_ids']: {micro_batch['input_ids'].min()}") if get_global_rank() == 0 else None
655
+ # print(f"max micro_batch['input_ids']: {micro_batch['input_ids'].max()}") if get_global_rank() == 0 else None
656
+ if get_fs_local_rank() == 0 and self.global_step in (1421, 1422, 1423):
657
+ # Debug-only: dump the micro-batch input IDs (one list of token IDs per row) to a text file at these steps.
658
+ with open(f"micro_batch_step{self.global_step}.txt", "w") as f:
659
+ for i in range(micro_batch["input_ids"].shape[0]):
660
+ f.write(f"{micro_batch['input_ids'][i].tolist()}\n")
661
+ ce_loss, z_loss, logits = self.model_forward(
662
+ micro_batch, compute_z_loss=self.cfg.softmax_auxiliary_loss
663
+ )
664
+ # print(f"End micro_batch at step {self.global_step} with ce_loss {ce_loss}, z_loss {z_loss}") if get_global_rank() == 0 else None
665
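+ # Scale by the number of micro-batches so gradient accumulation matches a single full-batch step.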
+ ce_loss = ce_loss / len(micro_batches)
666
+
667
+ # In case this helps with memory utilization.
668
+ del micro_batch
669
+
670
+ # Update overall CE batch loss.
671
+ ce_batch_loss += ce_loss.detach()
672
+
673
+ # Get loss to optimize for.
674
+ if self.cfg.softmax_auxiliary_loss:
675
+ assert z_loss is not None
676
+ assert z_batch_loss is not None
677
+ z_loss = z_loss / len(micro_batches)
678
+ loss = ce_loss + z_loss
679
+
680
+ # Update overall Z batch loss.
681
+ z_batch_loss += z_loss.detach()
682
+ else:
683
+ loss = ce_loss
684
+
685
+ del logits
686
+ # print(f"---before micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
687
+ # print(f" loss value: {loss}") if get_global_rank() == 0 else None
688
+ # Run backward pass.
689
+ loss.backward()
690
+ # print(f"---after micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
691
+ return ce_batch_loss, z_batch_loss
692
+
693
+ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
694
+ metrics: Dict[str, float] = {}
695
+
696
+ # Write data-indices to file.
697
+ if self.indices_file is not None and "index" in batch:
698
+ indices = "\t".join(str(int(i)) for i in batch["index"])
699
+ self.indices_file.write(f"{self.global_step}\t{indices}\n")
700
+
701
+ # Zero the gradients.
702
+ self.optim.zero_grad(set_to_none=True)
703
+
704
+ # Move tensors to the right device.
705
+ batch = move_to_device(batch, self.device)
706
+
707
+ # Run forward-backward pass.
708
+ ce_batch_loss, z_batch_loss = self.train_batch(batch)
709
+ # Collect loss, potentially reducing over all ranks.
710
+ if reduce_global_loss:
711
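+ # Reduce (sum) onto rank 0, then divide by the world size so rank 0 holds the mean loss across ranks.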
+ dist.reduce(ce_batch_loss, 0)
712
+ ce_batch_loss.div_(get_world_size())
713
+ if z_batch_loss is not None:
714
+ dist.reduce(z_batch_loss, 0)
715
+ z_batch_loss.div_(get_world_size())
716
+
717
+ # Clip gradient norms and collect param/gradient/optim metrics.
718
+ should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
719
+ optim_metrics = self.optim.clip_grads_and_collect_metrics(
720
+ self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
721
+ )
722
+
723
+ # Adjust the learning rate.
724
+ for group in self.optim.param_groups:
725
+ # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
726
+ # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
727
+ # the corresponding values from `self.cfg`.
728
+ group["lr"] = self.scheduler.get_lr(
729
+ self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
730
+ )
731
+ group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
732
+ self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
733
+ )
734
+ group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
735
+ self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
736
+ )
737
+
738
+ # Optimizer step.
739
+ self.optim.step()
740
+
741
+ # Collect metrics and check for NaN loss.
742
+ # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
743
+ if torch.isnan(ce_batch_loss):
744
+ raise ValueError("nan loss encountered")
745
+ if z_batch_loss is not None and torch.isnan(z_batch_loss):
746
+ raise ValueError("nan loss encountered")
747
+ for key, value in optim_metrics.items():
748
+ metrics[f"optim/{key}"] = value.item()
749
+ self.cur_train_loss = ce_batch_loss.item()
750
+ self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
751
+ metrics["train/CrossEntropyLoss"] = self.cur_train_loss
752
+ metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
753
+ if z_batch_loss is not None:
754
+ metrics["train/ZLoss"] = z_batch_loss.item()
755
+
756
+ # Maybe collect post-step optimizer-specific metrics.
757
+ if should_log_optim_metrics_this_step:
758
+ optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
759
+ for key, value in optim_metrics.items():
760
+ metrics[f"optim/{key}"] = value.item()
761
+
762
+ return metrics
763
+
764
+ def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
765
+ with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
766
+ ce_loss, _, logits = self.model_forward(batch, loss_reduction="none")
767
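+ # Average the per-token losses over the sequence dimension to get one loss value per example.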
+ return ce_loss.mean(dim=-1), logits
768
+
769
+ def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
770
+ # Move tensors to the right device.
771
+ batch = move_to_device(batch, self.device)
772
+
773
+ # Run forward pass.
774
+ with torch.no_grad(): # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
775
+ ce_loss, logits = self.eval_batch(batch)
776
+
777
+ # Update metrics.
778
+ evaluator.update_metrics(
779
+ batch, ce_loss, logits
780
+ ) # batch includes all keys that the downstream evaluation needs
781
+
782
+ barrier()
783
+
784
+ def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
785
+ microbatch_size = self.cfg.device_train_microbatch_size
786
+ batch_size = batch["input_ids"].shape[0]
787
+ if batch_size <= microbatch_size:
788
+ return [batch]
789
+ else:
790
+ micro_batches = {}
791
+ for key, value in batch.items():
792
+ if isinstance(value, torch.Tensor):
793
+ micro_batches[key] = value.split(microbatch_size, dim=0)
794
+ elif isinstance(value, list):
795
+ micro_batches[key] = [
796
+ value[microbatch_size * i : microbatch_size * i + microbatch_size]
797
+ for i in range(math.ceil(batch_size / microbatch_size))
798
+ ]
799
+ else:
800
+ raise ValueError(f"unexpected item in batch: '{key}={value}'")
801
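+ # Re-assemble the per-key chunks into a list of micro-batch dicts, one per chunk.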
+ return [
802
+ {key: value[i] for key, value in micro_batches.items()} # type: ignore
803
+ for i in range(len(micro_batches["input_ids"]))
804
+ ]
805
+
806
+ def system_metrics(self) -> Dict[str, float]:
807
+ metrics = {}
808
+ if self.global_step < 3 or self.global_step % 10 == 0:
809
+ peak_gpu_mb = peak_gpu_memory()
810
+ if peak_gpu_mb is not None:
811
+ metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
812
+ return metrics
813
+
814
+ def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
815
+ def format_float(value: float) -> str:
816
+ if value < 0.0001:
817
+ return str(value) # scientific notation
818
+ elif value > 1000:
819
+ return f"{int(value):,d}"
820
+ elif value > 100:
821
+ return f"{value:.1f}"
822
+ elif value > 10:
823
+ return f"{value:.2f}"
824
+ elif value > 1:
825
+ return f"{value:.3f}"
826
+ else:
827
+ return f"{value:.4f}"
828
+
829
+ log.info(
830
+ f"{prefix}\n"
831
+ + "\n".join(
832
+ [
833
+ f" {name}={format_float(value)}"
834
+ for name, value in metrics.items()
835
+ if not name.startswith("optim/") # there's too many optimizer metrics
836
+ ]
837
+ )
838
+ )
839
+
840
+ def should_log_optim_metrics_this_step(self) -> bool:
841
+ if self.cfg.wandb is None:
842
+ # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
843
+ # to log to the console.
844
+ return False
845
+ optim_log_interval = self.cfg.optimizer.metrics_log_interval
846
+ if optim_log_interval is None:
847
+ optim_log_interval = self.cfg.wandb.log_interval
848
+ else:
849
+ optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
850
+ return self.global_step % optim_log_interval == 0
851
+
852
+ def should_log_this_step(self) -> bool:
853
+ if self.global_step % self.cfg.console_log_interval == 0:
854
+ return True
855
+ elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
856
+ return True
857
+ else:
858
+ return False
859
+
860
+ def eval(self) -> Dict[str, Any]:
861
+ # Zero gradients and set model to 'eval' mode.
862
+ self.optim.zero_grad(set_to_none=True)
863
+ self.fsdp_model.eval()
864
+
865
+ eval_metrics = {}
866
+ for evaluator in self.evaluators:
867
+ log.info(f"Running evaluation for '{evaluator.label}'...")
868
+
869
+ # Reset metrics.
870
+ evaluator.reset_metrics()
871
+
872
+ # Initialize data loader iterator.
873
+ eval_batches = iter(evaluator.eval_loader)
874
+
875
+ # Adjust how many batches to evaluate on.
876
+ num_eval_batches = (
877
+ evaluator.subset_num_batches
878
+ if evaluator.subset_num_batches is not None
879
+ else self.cfg.eval_subset_num_batches
880
+ )
881
+ if num_eval_batches > 0:
882
+ num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
883
+ eval_batches = islice(eval_batches, num_eval_batches)
884
+
885
+ # Run model over batches.
886
+ for eval_step, eval_batch in enumerate(eval_batches):
887
+ self.eval_step(eval_batch, evaluator)
888
+
889
+ # Log to console.
890
+ if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
891
+ log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")
892
+
893
+ # Get final metrics.
894
+ metrics = evaluator.compute_metrics()
895
+ eval_metrics.update(metrics)
896
+ self.log_metrics_to_console(f"{evaluator.label}", metrics)
897
+
898
+ del eval_batches
899
+
900
+ return eval_metrics
901
+
902
+ def check_if_cancelled(self) -> Tuple[bool, int]:
903
+ should_cancel = False
904
+ cancel_reason: Optional[str] = None
905
+ extra_steps = 0
906
+ if get_global_rank() == 0:
907
+ if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
908
+ # First check if we've reached the training time limit.
909
+ should_cancel = True
910
+ cancel_reason = "time limit reached"
911
+ extra_steps = self.cfg.extra_steps_after_cancel
912
+ elif (
913
+ self.cfg.early_stopping_factor is not None
914
+ and self.global_step > self.cfg.scheduler.t_warmup
915
+ and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
916
+ ):
917
+ # Next check if early stopping loss criteria is met.
918
+ should_cancel = True
919
+ cancel_reason = "early stopping from loss increase"
920
+ elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
921
+ # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.
922
+ # Tags don't show up in the local run object, so we have to use the import/export API to check.
923
+ from requests.exceptions import RequestException
924
+
925
+ try:
926
+ api = wandb.Api(api_key=api_key)
927
+ run = api.run(wandb.run.path)
928
+ for tag in run.tags or []:
929
+ if tag.lower() in {"cancel", "canceled", "cancelled"}:
930
+ should_cancel = True
931
+ cancel_reason = "Weights & Biases tag"
932
+ extra_steps = self.cfg.extra_steps_after_cancel
933
+ break
934
+ except RequestException:
935
+ pass
936
+
937
+ run_canceled = synchronize_flag(should_cancel, self.device)
938
+ if run_canceled:
939
+ extra_steps = synchronize_value(extra_steps, self.device)
940
+ if cancel_reason is None:
941
+ if extra_steps > 0:
942
+ log.warning(f"Run canceled, stopping in {extra_steps} more steps...")
943
+ else:
944
+ log.warning("Run canceled")
945
+ else:
946
+ if extra_steps > 0:
947
+ log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
948
+ else:
949
+ log.warning(f"Run canceled due to {cancel_reason}")
950
+
951
+ return run_canceled, extra_steps
952
+
953
+ def fit(self):
954
+ if self.cfg.stop_after is not None:
955
+ if self.cfg.stop_at is None:
956
+ self.cfg.stop_at = self.global_step + self.cfg.stop_after
957
+ else:
958
+ self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)
959
+
960
+ self._start_time = time.time()
961
+ self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.
962
+
963
+ # Disable automatic garbage collection, FSDP doesn't work well with it.
964
+ if self.cfg.gen1_gc_interval is not None:
965
+ gc.disable()
966
+
967
+ if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
968
+ eval_metrics = self.eval()
969
+ if wandb.run is not None:
970
+ wandb.log(eval_metrics, step=self.global_step)
971
+
972
+ # Set model to 'train' mode.
973
+ self.fsdp_model.train()
974
+
975
+ # Initialize monitors.
976
+ assert self.cfg.device_train_batch_size is not None
977
+ speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
978
+ lr_monitor = LRMonitor(self.optim)
979
+
980
+ # Log system metrics at the start of training.
981
+ sys_metrics = self.system_metrics()
982
+ if sys_metrics:
983
+ self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
984
+ if wandb.run is not None:
985
+ wandb.log(sys_metrics, step=0)
986
+
987
+ # Python Profiler stuff
988
+ if self.cfg.python_profiling:
989
+ python_profiler = cProfile.Profile()
990
+ else:
991
+ python_profiler = None
992
+
993
+ # PyTorch Profiler stuff
994
+ if self.cfg.torch_profiling and get_global_rank() == 0:
995
+ from torch.profiler import schedule
996
+
997
+ profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
998
+
999
+ def on_trace_ready(p):
1000
+ profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
1001
+ profiler_output_dir.mkdir(exist_ok=True)
1002
+
1003
+ output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
1004
+ log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
1005
+ output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
1006
+ log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
1007
+
1008
+ p.export_chrome_trace(
1009
+ str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
1010
+ )
1011
+ if self.cfg.remote_save_folder is not None:
1012
+ upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
1013
+ log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
1014
+ upload(trace_path, f"{upload_folder}/{trace_path.name}")
1015
+
1016
+ from torch.profiler import ProfilerActivity
1017
+
1018
+ torch_profiler = torch.profiler.profile(
1019
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1020
+ record_shapes=False,
1021
+ profile_memory=False,
1022
+ with_stack=True,
1023
+ schedule=profiling_schedule,
1024
+ on_trace_ready=on_trace_ready,
1025
+ )
1026
+ del profiling_schedule
1027
+ else:
1028
+ import contextlib
1029
+
1030
+ torch_profiler = contextlib.nullcontext()
1031
+
1032
+ # Train.
1033
+ first_batch: bool = True
1034
+ cancel_initiated: bool = False
1035
+ stop_at: Optional[int] = self.cfg.stop_at
1036
+ save_checkpoints: bool = True
1037
+
1038
+ with torch_profiler as p:
1039
+ for epoch in range(self.epoch or 0, self.max_epochs):
1040
+ for batch in self.train_loader:
1041
+ # print(f" >>>>>>>>>>fit start with Global step: {self.global_step} <<<<<<<<<<<<<<<") if get_global_rank()==0 else None
1042
+ # Bookkeeping.
1043
+ # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
1044
+ # batches see the same number of tokens, which should be the case for language model pre-training
1045
+ # (at least when drop_last=True).
1046
+ # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
1047
+ # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
1048
+ # fail loudly.
1049
+ batch_size, seq_len = batch["input_ids"].shape
1050
+ assert seq_len == self.cfg.model.max_sequence_length
1051
+ assert batch_size == self.cfg.device_train_batch_size
1052
+ global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks
1053
+ self.global_step += 1
1054
+ self.global_train_examples_seen_this_epoch += global_batch_size
1055
+ self.global_train_tokens_seen += global_batch_size * seq_len
1056
+ speed_monitor.batch_start(
1057
+ self.global_train_tokens_seen,
1058
+ batch_size * seq_len, # num tokens in batch for this device
1059
+ # We start monitoring speed after the first batch since the first
1060
+ # batch might be an outlier due to compiling and other initialization overhead.
1061
+ record=not first_batch,
1062
+ )
1063
+
1064
+ should_log_this_step = self.should_log_this_step()
1065
+
1066
+ # Run train step on batch.
1067
+ metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
1068
+ # print(f" After train step with Global step: {self.global_step}") if get_global_rank()==0 else None
1069
+
1070
+ # Maybe collect other metrics.
1071
+ if should_log_this_step:
1072
+ # Speed metrics.
1073
+ metrics.update(speed_monitor.check())
1074
+ # System metrics.
1075
+ metrics.update(self.system_metrics())
1076
+ # Learning rate metrics.
1077
+ metrics.update(lr_monitor.check())
1078
+
1079
+ # Log metrics to console.
1080
+ if self.global_step % self.cfg.console_log_interval == 0:
1081
+ if get_global_rank() == 0:
1082
+ self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)
1083
+ else:
1084
+ log.info(f"[step={self.global_step}/{self.max_steps}]")
1085
+
1086
+ # Log metrics to W&B.
1087
+ if (
1088
+ wandb.run is not None
1089
+ and self.cfg.wandb is not None
1090
+ and self.global_step % self.cfg.wandb.log_interval == 0
1091
+ ):
1092
+ wandb.log(metrics, step=self.global_step)
1093
+
1094
+ # Check if/when run should be canceled.
1095
+ if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
1096
+ cancel_initiated, extra_steps = self.check_if_cancelled()
1097
+ if cancel_initiated:
1098
+ stop_at = (
1099
+ self.global_step + extra_steps
1100
+ if stop_at is None
1101
+ else min(self.global_step + extra_steps, stop_at)
1102
+ )
1103
+
1104
+ # Maybe save sharded checkpoint.
1105
+ if save_checkpoints and (
1106
+ cancel_initiated
1107
+ or (
1108
+ self.global_step % self.cfg.save_interval == 0
1109
+ and self.cfg.save_num_checkpoints_to_keep != 0
1110
+ )
1111
+ ):
1112
+ log.info("Saving checkpoint...")
1113
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1114
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1115
+
1116
+ # Remove any ephemeral checkpoints.
1117
+ while self.ephemeral_checkpoints:
1118
+ self.remove_ephemeral_checkpoint()
1119
+
1120
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1121
+ speed_monitor.reset()
1122
+
1123
+ # If the run was just canceled this will be the final checkpoint.
1124
+ if cancel_initiated:
1125
+ save_checkpoints = False
1126
+ elif (
1127
+ self.cfg.save_interval_ephemeral is not None
1128
+ and self.global_step % self.cfg.save_interval_ephemeral == 0
1129
+ ):
1130
+ log.info("Saving ephemeral checkpoint...")
1131
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
1132
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1133
+
1134
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1135
+ speed_monitor.reset()
1136
+
1137
+ # Maybe save unsharded checkpoint.
1138
+ if (
1139
+ save_checkpoints
1140
+ and self.cfg.save_interval_unsharded is not None
1141
+ and self.global_step % self.cfg.save_interval_unsharded == 0
1142
+ and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
1143
+ ):
1144
+ log.info("Saving unsharded checkpoint...")
1145
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1146
+ log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1147
+
1148
+ # Reset speed monitor so that we don't count the time taken to save checkpoints.
1149
+ speed_monitor.reset()
1150
+
1151
+ # Maybe run evaluations.
1152
+ if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
1153
+ eval_metrics = self.eval()
1154
+
1155
+ # Log metrics to W&B.
1156
+ if wandb.run is not None:
1157
+ wandb.log(eval_metrics, step=self.global_step)
1158
+
1159
+ # Reset speed monitor so that we don't count the time taken to run evaluations.
1160
+ speed_monitor.reset()
1161
+
1162
+ # Reset model to 'train' mode.
1163
+ self.fsdp_model.train()
1164
+
1165
+ # End of batch.
1166
+ first_batch = False
1167
+ if p is not None:
1168
+ p.step()
1169
+
1170
+ if stop_at is not None and self.global_step >= stop_at:
1171
+ break
1172
+
1173
+ # Run generation 1 garbage collection.
1174
+ if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
1175
+ gc.collect(1)
1176
+
1177
+ # Python Profiler stuff
1178
+ # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
1179
+ if python_profiler is not None:
1180
+ if self.global_step == 5:
1181
+ python_profiler.enable()
1182
+ elif self.global_step == 8:
1183
+ python_profiler.disable()
1184
+ python_profiler.print_stats(sort=SortKey.CUMULATIVE)
1185
+ python_profiler = None
1186
+ else:
1187
+ log.info("Training epoch complete")
1188
+ self.epoch = epoch + 1
1189
+ self.global_train_examples_seen_this_epoch = 0
1190
+ if self.epoch < self.max_epochs:
1191
+ self.dataset.reshuffle()
1192
+ continue
1193
+
1194
+ break
1195
+
1196
+ # Save final checkpoint.
1197
+ if save_checkpoints:
1198
+ if (
1199
+ self.cfg.save_interval_unsharded is not None
1200
+ and self.last_unsharded_checkpoint_step != self.global_step
1201
+ ):
1202
+ log.info("Saving final unsharded model checkpoint...")
1203
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1204
+ log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1205
+ elif (
1206
+ self.cfg.save_num_checkpoints_to_keep != 0
1207
+ and self.last_sharded_checkpoint_step != self.global_step
1208
+ ):
1209
+ log.info("Saving final checkpoint...")
1210
+ checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1211
+ log.info(f"Checkpoint saved to {checkpoint_path}")
1212
+
1213
+ def close(self, exit_code: int = 0) -> None:
1214
+ gc_cuda()
1215
+
1216
+ if self.indices_file is not None:
1217
+ self.indices_file.flush()
1218
+ self.indices_file.close()
1219
+ if self._gc_init_state:
1220
+ gc.enable()
1221
+ else:
1222
+ gc.disable()
1223
+ if wandb.run is not None:
1224
+ wandb.finish(exit_code=exit_code, quiet=True)
1225
+
1226
+ def __enter__(self) -> Trainer:
1227
+ return self
1228
+
1229
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
1230
+ del exc_val, exc_tb
1231
+ self.close(0 if exc_type is None else 1)
model/util.py ADDED
@@ -0,0 +1,681 @@
1
+ import logging
2
+ import os
3
+ import re
4
+ import socket
5
+ import sys
6
+ import time
7
+ import warnings
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from itertools import cycle, islice
11
+ from pathlib import Path
12
+ from queue import Queue
13
+ from threading import Thread
14
+ from typing import Any, Callable, Dict, Optional, Union
15
+
16
+ import boto3
17
+ import botocore.exceptions as boto_exceptions
18
+ import rich
19
+ from botocore.config import Config
20
+ from rich.console import Console, ConsoleRenderable
21
+ from rich.highlighter import NullHighlighter
22
+ from rich.progress import Progress
23
+ from rich.text import Text
24
+ from rich.traceback import Traceback
25
+
26
+ from .aliases import PathOrStr
27
+ from .exceptions import (
28
+ OLMoCliError,
29
+ OLMoEnvironmentError,
30
+ OLMoError,
31
+ OLMoNetworkError,
32
+ OLMoThreadError,
33
+ )
34
+ from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
35
+
36
+ try:
37
+ from functools import cache
38
+ except ImportError:
39
+ from functools import lru_cache as cache
40
+
41
+
42
+ class StrEnum(str, Enum):
43
+ """
44
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
45
+ We include it here for compatibility with older versions of Python.
46
+ """
47
+
48
+ def __str__(self) -> str:
49
+ return self.value
50
+
51
+ def __repr__(self) -> str:
52
+ return f"'{str(self)}'"
53
+
54
+
55
+ _log_extra_fields: Dict[str, Any] = {}
56
+ log = logging.getLogger(__name__)
57
+
58
+
59
+ class LogFilterType(StrEnum):
60
+ rank0_only = "rank0_only"
61
+ local_rank0_only = "local_rank0_only"
62
+ all_ranks = "all_ranks"
63
+
64
+
65
+ def log_extra_field(field_name: str, field_value: Any) -> None:
66
+ global _log_extra_fields
67
+ if field_value is None:
68
+ if field_name in _log_extra_fields:
69
+ del _log_extra_fields[field_name]
70
+ else:
71
+ _log_extra_fields[field_name] = field_value
72
+
73
+
74
+ def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
75
+ """
76
+ :param log_filter_type: Controls which ranks emit INFO-and-below messages (rank 0 only, local rank 0 only, or all ranks).
77
+ """
78
+ log_extra_field("hostname", socket.gethostname())
79
+ if is_distributed():
80
+ log_extra_field("node_rank", get_node_rank())
81
+ log_extra_field("local_rank", get_local_rank())
82
+ log_extra_field("global_rank", get_global_rank())
83
+ else:
84
+ log_extra_field("node_rank", 0)
85
+ log_extra_field("local_rank", 0)
86
+ log_extra_field("global_rank", 0)
87
+
88
+ old_log_record_factory = logging.getLogRecordFactory()
89
+
90
+ def log_record_factory(*args, **kwargs) -> logging.LogRecord:
91
+ record = old_log_record_factory(*args, **kwargs)
92
+ for field_name, field_value in _log_extra_fields.items():
93
+ setattr(record, field_name, field_value)
94
+ return record
95
+
96
+ logging.setLogRecordFactory(log_record_factory)
97
+
98
+ handler: logging.Handler
99
+ if (
100
+ os.environ.get("OLMo_NONINTERACTIVE", False)
101
+ or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
102
+ or not sys.stdout.isatty()
103
+ ):
104
+ handler = logging.StreamHandler(sys.stdout)
105
+ formatter = logging.Formatter(
106
+ "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
107
+ )
108
+ formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
109
+ formatter.default_msec_format = "%s.%03d"
110
+ handler.setFormatter(formatter)
111
+ else:
112
+ handler = RichHandler()
113
+
114
+ def rank0_filter(record: logging.LogRecord) -> int:
115
+ if record.levelno > logging.INFO:
116
+ return 1
117
+ if getattr(record, "global_rank", 0) == 0:
118
+ return 1
119
+ else:
120
+ return 0
121
+
122
+ def local_rank0_filter(record: logging.LogRecord) -> int:
123
+ if record.levelno > logging.INFO:
124
+ return 1
125
+ if getattr(record, "local_rank", 0) == 0:
126
+ return 1
127
+ else:
128
+ return 0
129
+
130
+ if log_filter_type == LogFilterType.rank0_only:
131
+ filter = rank0_filter
132
+ elif log_filter_type == LogFilterType.local_rank0_only:
133
+ filter = local_rank0_filter # type: ignore
134
+ elif log_filter_type == LogFilterType.all_ranks:
135
+ filter = None
136
+ else:
137
+ raise ValueError(log_filter_type)
138
+
139
+ if filter is not None:
140
+ handler.addFilter(filter) # type: ignore
141
+ logging.basicConfig(handlers=[handler], level=logging.INFO)
142
+
143
+ logging.captureWarnings(True)
144
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
145
+
146
+
147
+ def excepthook(exctype, value, traceback):
148
+ """
149
+ Used to patch `sys.excepthook` in order to log exceptions.
150
+ """
151
+ if issubclass(exctype, KeyboardInterrupt):
152
+ sys.__excepthook__(exctype, value, traceback)
153
+ elif issubclass(exctype, OLMoCliError):
154
+ rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
155
+ elif issubclass(exctype, OLMoError):
156
+ rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
157
+ else:
158
+ log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))
159
+
160
+
161
+ def install_excepthook():
162
+ sys.excepthook = excepthook
163
+
164
+
165
+ def filter_warnings():
166
+ # Filter internal deprecation warnings from torch
167
+ warnings.filterwarnings(
168
+ action="ignore",
169
+ category=UserWarning,
170
+ message="torch.distributed.*_base is a private function and will be deprecated.*",
171
+ )
172
+ warnings.filterwarnings(
173
+ action="ignore",
174
+ category=UserWarning,
175
+ message="TypedStorage is deprecated.*",
176
+ )
177
+ warnings.filterwarnings(
178
+ action="ignore",
179
+ category=UserWarning,
180
+ message="Please use DTensor instead.*",
181
+ )
182
+ # Torchvision warnings. We don't actually use torchvision.
183
+ warnings.filterwarnings(
184
+ action="ignore",
185
+ message="failed to load.*",
186
+ module="torchvision.io.image",
187
+ )
188
+
189
+
190
+ def set_env_variables():
191
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
192
+
193
+
194
+ def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
195
+ if log_filter_type is None:
196
+ log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
197
+ rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
198
+ setup_logging(log_filter_type=log_filter_type)
199
+ install_excepthook()
200
+ filter_warnings()
201
+ set_env_variables()
202
+
203
+
204
+ def clean_opt(arg: str) -> str:
205
+ if "=" not in arg:
206
+ arg = f"{arg}=True"
207
+ name, val = arg.split("=", 1)
208
+ name = name.strip("-").replace("-", "_")
209
+ return f"{name}={val}"
210
+
211
+
212
+ class RichHandler(logging.Handler):
213
+ """
214
+ A simplified version of rich.logging.RichHandler from
215
+ https://github.com/Textualize/rich/blob/master/rich/logging.py
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ *,
221
+ level: Union[int, str] = logging.NOTSET,
222
+ console: Optional[Console] = None,
223
+ markup: bool = False,
224
+ ) -> None:
225
+ super().__init__(level=level)
226
+ self.console = console or rich.get_console()
227
+ self.highlighter = NullHighlighter()
228
+ self.markup = markup
229
+
230
+ def emit(self, record: logging.LogRecord) -> None:
231
+ try:
232
+ if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
233
+ self.console.print(record.msg)
234
+ else:
235
+ msg: Any = record.msg
236
+ if isinstance(record.msg, str):
237
+ msg = self.render_message(record=record, message=record.getMessage())
238
+ renderables = [
239
+ self.get_time_text(record),
240
+ self.get_level_text(record),
241
+ self.get_location_text(record),
242
+ msg,
243
+ ]
244
+ if record.exc_info is not None:
245
+ tb = Traceback.from_exception(*record.exc_info) # type: ignore
246
+ renderables.append(tb)
247
+ self.console.print(*renderables)
248
+ except Exception:
249
+ self.handleError(record)
250
+
251
+ def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
252
+ use_markup = getattr(record, "markup", self.markup)
253
+ message_text = Text.from_markup(message) if use_markup else Text(message)
254
+
255
+ highlighter = getattr(record, "highlighter", self.highlighter)
256
+ if highlighter:
257
+ message_text = highlighter(message_text)
258
+
259
+ return message_text
260
+
261
+ def get_time_text(self, record: logging.LogRecord) -> Text:
262
+ log_time = datetime.fromtimestamp(record.created)
263
+ time_str = log_time.strftime("[%Y-%m-%d %X]")
264
+ return Text(time_str, style="log.time", end=" ")
265
+
266
+ def get_level_text(self, record: logging.LogRecord) -> Text:
267
+ level_name = record.levelname
268
+ level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
269
+ level_text.style = "log.level"
270
+ level_text.end = " "
271
+ return level_text
272
+
273
+ def get_location_text(self, record: logging.LogRecord) -> Text:
274
+ name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
275
+ text = f"[{name_and_line}, rank={record.local_rank}]" # type: ignore
276
+ return Text(text, style="log.path")
277
+
278
+
279
+ def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
280
+ """Wait for the condition function to return True."""
281
+ start_time = time.monotonic()
282
+ while not condition():
283
+ time.sleep(0.5)
284
+ if time.monotonic() - start_time > timeout:
285
+ raise TimeoutError(f"{description} timed out")
286
+
287
+
288
+ def is_url(path: PathOrStr) -> bool:
289
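+ # Anything with a URL-style scheme prefix (e.g. 's3://', 'gs://', 'https://') is treated as remote.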
+ return re.match(r"[a-z0-9]+://.*", str(path)) is not None
290
+
291
+
292
+ def dir_is_empty(dir: PathOrStr) -> bool:
293
+ dir = Path(dir)
294
+ if not dir.is_dir():
295
+ return True
296
+ try:
297
+ next(dir.glob("*"))
298
+ return False
299
+ except StopIteration:
300
+ return True
301
+
302
+
303
+ def get_progress_bar() -> Progress:
304
+ from cached_path import get_download_progress
305
+
306
+ return get_download_progress()
307
+
308
+
309
+ def resource_path(
310
+ folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
311
+ ) -> Path:
312
+ if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
313
+ log.info(f"Found local cache of {fname} at {local_path}")
314
+ return local_path
315
+ else:
316
+ from cached_path import cached_path
317
+
318
+ return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)
319
+
320
+
321
+ def file_size(path: PathOrStr) -> int:
322
+ """
323
+ Get the size of a local or remote file in bytes.
324
+ """
325
+ if is_url(path):
326
+ from urllib.parse import urlparse
327
+
328
+ parsed = urlparse(str(path))
329
+ if parsed.scheme == "gs":
330
+ return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
331
+ elif parsed.scheme in ("s3", "r2"):
332
+ return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
333
+ elif parsed.scheme in ("http", "https"):
334
+ return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
335
+ elif parsed.scheme == "file":
336
+ return file_size(str(path).replace("file://", "", 1))
337
+ else:
338
+ raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
339
+ else:
340
+ return os.stat(path).st_size
341
+
342
+
343
+ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
344
+ """Upload source file to a target location on GCS or S3."""
345
+ from urllib.parse import urlparse
346
+
347
+ source = Path(source)
348
+ assert source.is_file()
349
+ parsed = urlparse(target)
350
+ if parsed.scheme == "gs":
351
+ _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
352
+ elif parsed.scheme in ("s3", "r2"):
353
+ _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
354
+ else:
355
+ raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
356
+
357
+
358
+ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
359
+ if is_url(source):
360
+ from urllib.parse import urlparse
361
+
362
+ parsed = urlparse(str(source))
363
+ if parsed.scheme == "gs":
364
+ return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
365
+ elif parsed.scheme in ("s3", "r2"):
366
+ return _s3_get_bytes_range(
367
+ parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
368
+ )
369
+ elif parsed.scheme in ("http", "https"):
370
+ return _http_get_bytes_range(
371
+ parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
372
+ )
373
+ elif parsed.scheme == "file":
374
+ return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
375
+ else:
376
+ raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files")
377
+ else:
378
+ with open(source, "rb") as f:
379
+ f.seek(bytes_start)
380
+ return f.read(num_bytes)
381
+
382
+
383
+ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
384
+ if is_url(dir):
385
+ from urllib.parse import urlparse
386
+
387
+ parsed = urlparse(str(dir))
388
+ if parsed.scheme == "gs":
389
+ raise NotImplementedError
390
+ elif parsed.scheme in ("s3", "r2"):
391
+ return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
392
+ elif parsed.scheme == "file":
393
+ return find_latest_checkpoint(str(dir).replace("file://", "", 1))
394
+ else:
395
+ raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
396
+ else:
397
+ latest_step = 0
398
+ latest_checkpoint: Optional[Path] = None
399
+ for path in Path(dir).glob("step*"):
400
+ if path.is_dir():
401
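+ # Checkpoint directories are named like 'step1000' or 'step1000-unsharded'; extract the step number.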
+ try:
402
+ step = int(path.name.replace("step", "").replace("-unsharded", ""))
403
+ except ValueError:
404
+ continue
405
+ # We prioritize sharded checkpoints over unsharded checkpoints.
406
+ if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
407
+ latest_step = step
408
+ latest_checkpoint = path
409
+ return latest_checkpoint
410
+
411
+
412
+ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
413
+ from google.cloud import storage as gcs
414
+
415
+ storage_client = gcs.Client()
416
+ bucket = storage_client.bucket(bucket_name)
417
+ blob = bucket.blob(key)
418
+ if not save_overwrite and blob.exists():
419
+ raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
420
+ blob.upload_from_filename(source)
421
+
422
+
423
+ def _gcs_file_size(bucket_name: str, key: str) -> int:
424
+ from google.api_core.exceptions import NotFound
425
+ from google.cloud import storage as gcs
426
+
427
+ storage_client = gcs.Client()
428
+ bucket = storage_client.bucket(bucket_name)
429
+ blob = bucket.blob(key)
430
+ try:
431
+ blob.reload()
432
+ except NotFound:
433
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
434
+ assert blob.size is not None
435
+ return blob.size
436
+
437
+
438
+ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
439
+ from google.api_core.exceptions import NotFound
440
+ from google.cloud import storage as gcs
441
+
442
+ storage_client = gcs.Client()
443
+ bucket = storage_client.bucket(bucket_name)
444
+ blob = bucket.blob(key)
445
+ try:
446
+ blob.reload()
447
+ except NotFound:
448
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
449
+ return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
450
+
451
+
452
+ def _get_s3_profile_name(scheme: str) -> Optional[str]:
453
+ if scheme == "s3":
454
+ # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
455
+ return os.environ.get("S3_PROFILE")
456
+ if scheme == "r2":
457
+ profile_name = os.environ.get("R2_PROFILE")
458
+ if profile_name is None:
459
+ raise OLMoEnvironmentError(
460
+ "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
461
+ )
462
+
463
+ return profile_name
464
+
465
+ raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
466
+
467
+
468
+ def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
469
+ if scheme == "s3":
470
+ return None
471
+ if scheme == "r2":
472
+ r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
473
+ if r2_endpoint_url is None:
474
+ raise OLMoEnvironmentError(
475
+ "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
476
+ )
477
+
478
+ return r2_endpoint_url
479
+
480
+ raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
481
+
482
+
483
+ @cache
484
+ def _get_s3_client(scheme: str):
485
+ session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
486
+ return session.client(
487
+ "s3",
488
+ endpoint_url=_get_s3_endpoint_url(scheme),
489
+ config=Config(retries={"max_attempts": 10, "mode": "standard"}),
490
+ use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
491
+ )
492
+
493
+
494
+ def _wait_before_retry(attempt: int):
495
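+ # Exponential backoff, capped at 3 seconds.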
+ time.sleep(min(0.5 * 2**attempt, 3.0))
496
+
497
+
498
+ def _s3_upload(
499
+ source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
500
+ ):
501
+ err: Optional[Exception] = None
502
+ if not save_overwrite:
503
+ for attempt in range(1, max_attempts + 1):
504
+ try:
505
+ _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
506
+ raise FileExistsError(
507
+ f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
508
+ )
509
+ except boto_exceptions.ClientError as e:
510
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
511
+ err = None
512
+ break
513
+ err = e
514
+
515
+ if attempt < max_attempts:
516
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
517
+ _wait_before_retry(attempt)
518
+
519
+ if err is not None:
520
+ raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err
521
+
522
+ try:
523
+ _get_s3_client(scheme).upload_file(source, bucket_name, key)
524
+ except boto_exceptions.ClientError as e:
525
+ raise OLMoNetworkError(f"Failed to upload to {scheme}") from e
526
+
527
+
528
+ def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
529
+ err: Optional[Exception] = None
530
+ for attempt in range(1, max_attempts + 1):
531
+ try:
532
+ return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
533
+ except boto_exceptions.ClientError as e:
534
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
535
+ raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
536
+ err = e
537
+
538
+ if attempt < max_attempts:
539
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
540
+ _wait_before_retry(attempt)
541
+
542
+ raise OLMoNetworkError(f"Failed to get {scheme} file size") from err
543
+
544
+
545
+ def _s3_get_bytes_range(
546
+ scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
547
+ ) -> bytes:
548
+ err: Optional[Exception] = None
549
+ for attempt in range(1, max_attempts + 1):
550
+ try:
551
+ return (
552
+ _get_s3_client(scheme)
553
+ .get_object(
554
+ Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
555
+ )["Body"]
556
+ .read()
557
+ )
558
+ except boto_exceptions.ClientError as e:
559
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
560
+ raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
561
+ err = e
562
+ except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
563
+ # ResponseStreamingError (subclass of HTTPClientError) can happen as
564
+ # a result of a failed read from the stream (http.client.IncompleteRead).
565
+ # Retrying can help in this case.
566
+ err = e
567
+
568
+ if attempt < max_attempts:
569
+ log.warning(
570
+ "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
571
+ )
572
+ _wait_before_retry(attempt)
573
+
574
+ # When torch's DataLoader intercepts exceptions, it may try to re-raise them
575
+ # by recalling their constructor with a single message arg. Torch has some
576
+ # logic to deal with the absence of a single-parameter constructor, but it
577
+ # doesn't gracefully handle other possible failures in calling such a constructor
578
+ # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
579
+ # in us losing the true exception info. To avoid this, we change the exception
580
+ # to a type that has a single-parameter constructor.
581
+ raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
582
+
583
+
584
+ def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
585
+ if not prefix.endswith("/"):
586
+ prefix = f"{prefix}/"
587
+ response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
588
+ assert not response["IsTruncated"] # need to handle this if it happens
589
+ latest_step = 0
590
+ latest_checkpoint: Optional[str] = None
591
+ for item in response["CommonPrefixes"]:
592
+ prefix = item["Prefix"].strip("/")
593
+ checkpoint_name = os.path.split(prefix)[-1]
594
+ if not checkpoint_name.startswith("step"):
595
+ continue
596
+ try:
597
+ step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
598
+ except ValueError:
599
+ continue
600
+ # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
601
+ # (the upload might have failed part way through).
602
+ try:
603
+ _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
604
+ except FileNotFoundError:
605
+ continue
606
+ # We prioritize sharded checkpoints over unsharded ones.
607
+ if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
608
+ latest_step = step
609
+ latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}"
610
+ return latest_checkpoint
611
+
612
+
613
+ def _http_file_size(scheme: str, host_name: str, path: str) -> int:
614
+ import requests
615
+
616
+ response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True)
617
+ return int(response.headers.get("content-length"))
618
+
619
+
620
+ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes:
621
+ import requests
622
+
623
+ response = requests.get(
624
+ f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"}
625
+ )
626
+ result = response.content
627
+ assert (
628
+ len(result) == num_bytes
629
+ ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything
630
+ return result
631
+
632
+
633
+ def default_thread_count() -> int:
634
+ return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
635
+
636
+
637
+ def pass_through_fn(fn, *args, **kwargs):
638
+ return fn(*args, **kwargs)
639
+
640
+
641
+ def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
642
+ q: Queue = Queue(maxsize=maxsize)
643
+
644
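+ # A unique sentinel object marks the end of the generator so the consumer loop knows when to stop.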
+ sentinel = object()
645
+
646
+ def fill_queue():
647
+ try:
648
+ for value in g:
649
+ q.put(value)
650
+ except Exception as e:
651
+ q.put(e)
652
+ finally:
653
+ q.put(sentinel)
654
+
655
+ thread_name = thread_name or repr(g)
656
+ thread = Thread(name=thread_name, target=fill_queue, daemon=True)
657
+ thread.start()
658
+
659
+ for x in iter(q.get, sentinel):
660
+ if isinstance(x, Exception):
661
+ raise OLMoThreadError(f"generator thread {thread_name} failed") from x
662
+ else:
663
+ yield x
664
+
665
+
666
+ def roundrobin(*iterables):
667
+ """
668
+ Call the given iterables in a round-robin fashion. For example:
669
+ ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
670
+ """
671
+ # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
672
+ num_active = len(iterables)
673
+ nexts = cycle(iter(it).__next__ for it in iterables)
674
+ while num_active:
675
+ try:
676
+ for next in nexts:
677
+ yield next()
678
+ except StopIteration:
679
+ # Remove the iterator we just exhausted from the cycle.
680
+ num_active -= 1
681
+ nexts = cycle(islice(nexts, num_active))
model/version.py ADDED
@@ -0,0 +1,11 @@
1
+ _MAJOR = "0"
2
+ _MINOR = "3"
3
+ # On main and in a nightly release the patch should be one ahead of the last
4
+ # released build.
5
+ _PATCH = "0"
6
+ # This is mainly for nightly builds which have the suffix ".dev$DATE". See
7
+ # https://semver.org/#is-v123-a-semantic-version for the semantics.
8
+ _SUFFIX = ""
9
+
10
+ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
11
+ VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
optim.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a93dfdd0bbd50edd0b30fba9adea180780e7010e4ba0b40a79034fdb48630a1f
3
+ size 302102214
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d27b5b4bdf76917ea2b9366e0db46c302d8d2441ba866ae55b9ecffd5c2bc034
3
+ size 151047623
train.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591449343a28aa3f7c41b042d7416e76e1fa3d304d7e0037c64ad3169abde7e0
3
+ size 14988