zehui127 committed on
Commit 278ed52 · verified
1 Parent(s): f7c5681

Delete train.py

Files changed (1)
  1. train.py +0 -1231
train.py DELETED
@@ -1,1231 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import cProfile
4
- import gc
5
- import logging
6
- import math
7
- import os
8
- import random
9
- import shutil
10
- import time
11
- from collections import deque
12
- from dataclasses import dataclass, field
13
- from itertools import islice
14
- from pathlib import Path
15
- from pstats import SortKey
16
- from typing import Any, Callable, Deque, Dict, List, Optional, TextIO, Tuple
17
-
18
- import numpy as np
19
- import torch
20
- import torch.distributed as dist
21
- import torch.nn.functional as F
22
- import wandb
23
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
24
- from torch.utils.data import DataLoader
25
-
26
- from .aliases import PathOrStr
27
- from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
28
- from .config import (
29
- CheckpointType,
30
- SchedulerUnits,
31
- ShardedCheckpointerType,
32
- SpeedMonitorConfig,
33
- TrainConfig,
34
- )
35
- from .data import IterableDataset
36
- from .eval import Evaluator
37
- from .exceptions import OLMoConfigurationError
38
- from .model import OLMo
39
- from .optim import Optimizer, Scheduler
40
- from .torch_util import (
41
- barrier,
42
- gc_cuda,
43
- get_fs_local_rank,
44
- get_global_rank,
45
- get_world_size,
46
- move_to_device,
47
- peak_gpu_memory,
48
- synchronize_flag,
49
- synchronize_value,
50
- )
51
- from .util import upload
52
-
53
- __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
54
-
55
- log = logging.getLogger(__name__)
56
-
57
-
58
- @dataclass
59
- class SpeedMonitor:
60
- cfg: SpeedMonitorConfig
61
- start_times: Deque[float] = field(default_factory=lambda: deque([]))
62
- global_total_tokens: int = 0
63
- device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))
64
-
65
- def batch_start(self, global_total_tokens: int, device_batch_num_tokens: int, record: bool = True) -> None:
66
- self.global_total_tokens = global_total_tokens
67
- if record:
68
- if len(self.start_times) >= self.cfg.window_size:
69
- self.start_times.popleft()
70
- self.device_interval_tokens.popleft()
71
- self.start_times.append(time.monotonic())
72
- self.device_interval_tokens.append(device_batch_num_tokens)
73
-
74
- def reset(self) -> None:
75
- self.start_times.clear()
76
- self.device_interval_tokens.clear()
77
-
78
- def check(self) -> Dict[str, float]:
79
- metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
80
- if self.start_times:
81
- interval_seconds = time.monotonic() - self.start_times[0]
82
- interval_batches = len(self.start_times)
83
- interval_tokens = sum(self.device_interval_tokens)
84
- metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
85
- metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
86
- return metrics
87
-
88
-
89
- @dataclass
90
- class LRMonitor:
91
- optim: torch.optim.Optimizer
92
-
93
- def check(self) -> Dict[str, float]:
94
- lrs = [group["lr"] for group in self.optim.param_groups]
95
- return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}
96
-
97
-
98
- def cross_entropy_loss(
99
- logits, labels, ignore_index: int = -100, reduction: str = "mean", compute_z_loss: bool = False
100
- ):
101
- loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
102
-
103
- if not compute_z_loss:
104
- return loss, None
105
-
106
- z_squared = logits.logsumexp(-1).pow(2)
107
- if reduction == "mean":
108
- z_squared = (z_squared * (labels != ignore_index)).mean()
109
- elif reduction == "sum":
110
- z_squared = (z_squared * (labels != ignore_index)).sum()
111
-
112
- z_loss = 1e-4 * z_squared
113
-
114
- return loss, z_loss
115
-
116
-
117
- @dataclass
118
- class Trainer:
119
- cfg: TrainConfig
120
- model: OLMo
121
- fsdp_model: FSDP
122
- optim: Optimizer
123
- scheduler: Scheduler
124
- train_loader: DataLoader
125
- device: torch.device
126
- evaluators: List[Evaluator]
127
- epoch: Optional[int] = None
128
- global_step: int = 0
129
- global_train_examples_seen_this_epoch: int = 0
130
- """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
131
- the data loader position on restarts."""
132
- global_train_tokens_seen: int = 0
133
- """Tracks the global total number of tokens trained on."""
134
- checkpoints: List[Path] = field(default_factory=list)
135
- unsharded_checkpoints: List[Path] = field(default_factory=list)
136
- ephemeral_checkpoints: List[Path] = field(default_factory=list)
137
- min_train_loss: float = float("inf")
138
- cur_train_loss: float = float("inf")
139
- indices_file: Optional[TextIO] = None
140
- _start_time: float = 0.0
141
- _gc_init_state: bool = True
142
- loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
143
- last_sharded_checkpoint_step: Optional[int] = None
144
- last_unsharded_checkpoint_step: Optional[int] = None
145
-
146
- def __post_init__(self):
147
- if self.cfg.fused_loss:
148
- from flash_attn.ops.triton.cross_entropy import ( # type: ignore
149
- cross_entropy_loss,
150
- )
151
-
152
- def fused_loss_fn(
153
- logits, labels, ignore_index: int = -100, reduction: str = "mean", compute_z_loss: bool = False
154
- ):
155
- loss, z_loss = cross_entropy_loss(
156
- logits,
157
- labels,
158
- label_smoothing=0.0,
159
- logit_scale=1.0,
160
- lse_square_scale=0.0,
161
- ignored_index=ignore_index,
162
- inplace_backward=False,
163
- process_group=None,
164
- )
165
-
166
- mask = labels != ignore_index
167
-
168
- if reduction == "mean":
169
- loss = loss.sum() / mask.sum()
170
- elif reduction == "sum":
171
- loss = loss.sum()
172
- else:
173
- loss = loss
174
-
175
- if not compute_z_loss:
176
- return loss, None
177
-
178
- if reduction == "mean":
179
- z_loss = z_loss.sum() / mask.sum()
180
- elif reduction == "sum":
181
- z_loss = z_loss.sum()
182
- else:
183
- z_loss = z_loss
184
-
185
- return loss, z_loss
186
-
187
- self.loss_fn = fused_loss_fn
188
-
189
- @property
190
- def dataset(self) -> IterableDataset:
191
- assert isinstance(self.train_loader.dataset, IterableDataset)
192
- return self.train_loader.dataset
193
-
194
- @property
195
- def tokens_per_batch(self) -> int:
196
- return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
197
-
198
- @property
199
- def batches_per_epoch(self) -> int:
200
- return self.dataset.total_size // self.cfg.global_train_batch_size
201
-
202
- @property
203
- def max_epochs(self) -> int:
204
- if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
205
- return int(self.cfg.max_duration[:-2].strip())
206
- else:
207
- return 1
208
-
209
- @property
210
- def max_steps(self) -> int:
211
- if isinstance(self.cfg.max_duration, int):
212
- return self.cfg.max_duration
213
- elif isinstance(self.cfg.max_duration, str):
214
- if self.cfg.max_duration.endswith("T"):
215
- # convert to float *first* to handle scientific notation
216
- max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
217
- tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
218
- steps_remaining = tokens_remaining // self.tokens_per_batch
219
- return self.global_step + steps_remaining
220
- elif self.cfg.max_duration.endswith("ep"):
221
- max_epochs = int(self.cfg.max_duration[:-2].strip())
222
- return max_epochs * self.batches_per_epoch
223
- else:
224
- # convert to float *first* to handle scientific notation
225
- return int(float(self.cfg.max_duration))
226
- else:
227
- raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
228
-
229
- @property
230
- def max_tokens(self) -> int:
231
- if isinstance(self.cfg.max_duration, int):
232
- return (
233
- self.global_train_tokens_seen
234
- + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
235
- )
236
- elif isinstance(self.cfg.max_duration, str):
237
- if self.cfg.max_duration.endswith("T"):
238
- # convert to float *first* to handle scientific notation
239
- return int(float(self.cfg.max_duration[:-1].strip()))
240
- elif self.cfg.max_duration.endswith("ep"):
241
- max_epochs = int(self.cfg.max_duration[:-2].strip())
242
- return max_epochs * self.batches_per_epoch * self.tokens_per_batch
243
- else:
244
- # convert to float *first* to handle scientific notation
245
- return (
246
- self.global_train_tokens_seen
247
- + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
248
- )
249
- else:
250
- raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
251
-
252
- @property
253
- def scheduler_current(self) -> int:
254
- if self.cfg.scheduler.units == SchedulerUnits.steps:
255
- return self.global_step
256
- elif self.cfg.scheduler.units == SchedulerUnits.tokens:
257
- return self.global_train_tokens_seen
258
- else:
259
- raise NotImplementedError(self.cfg.scheduler.units)
260
-
261
- @property
262
- def scheduler_max(self) -> int:
263
- if self.cfg.scheduler.units == SchedulerUnits.steps:
264
- return self.max_steps
265
- elif self.cfg.scheduler.units == SchedulerUnits.tokens:
266
- return self.max_tokens
267
- else:
268
- raise NotImplementedError(self.cfg.scheduler.units)
269
-
270
- def trainer_state_dict(self) -> Dict[str, Any]:
271
- return {
272
- "epoch": self.epoch,
273
- "global_step": self.global_step,
274
- "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
275
- "global_train_tokens_seen": self.global_train_tokens_seen,
276
- "world_size": get_world_size(),
277
- "checkpoints": self.checkpoints,
278
- "unsharded_checkpoints": self.unsharded_checkpoints,
279
- "ephemeral_checkpoints": self.ephemeral_checkpoints,
280
- "rng": {
281
- "python": random.getstate(),
282
- "numpy": np.random.get_state(),
283
- "torch": torch.random.get_rng_state(),
284
- "cuda": torch.cuda.get_rng_state(),
285
- },
286
- }
287
-
288
- def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
289
- # Checkpoint paths.
290
- self.checkpoints = [
291
- path
292
- for path in state_dict["checkpoints"]
293
- if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
294
- ]
295
- self.unsharded_checkpoints = [
296
- path
297
- for path in state_dict["unsharded_checkpoints"]
298
- if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
299
- ]
300
- self.ephemeral_checkpoints = [
301
- path
302
- for path in state_dict.get("ephemeral_checkpoints", [])
303
- if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
304
- ]
305
-
306
- # Dataset / dataloader position.
307
- checkpoint_epoch = state_dict.get("epoch", 0)
308
- self.global_step = state_dict["global_step"]
309
- self.global_train_examples_seen_this_epoch = state_dict.get(
310
- "global_train_examples_seen_this_epoch",
311
- state_dict.get( # for backwards compatibility
312
- "global_train_examples_seen",
313
- state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
314
- ),
315
- )
316
- self.global_train_tokens_seen = state_dict.get(
317
- "global_train_tokens_seen",
318
- state_dict.get("global_data_step", self.global_step) # for backwards compatibility
319
- * self.cfg.global_train_batch_size
320
- * self.cfg.model.max_sequence_length,
321
- )
322
-
323
- if not self.cfg.restore_dataloader:
324
- self.epoch = 0
325
- self.global_train_tokens_seen = 0
326
- self.global_train_examples_seen_this_epoch = 0
327
- elif self.epoch is None:
328
- self.epoch = checkpoint_epoch
329
- elif checkpoint_epoch != self.epoch:
330
- log.info(f"Starting new epoch (epoch = {self.epoch})")
331
- self.global_train_examples_seen_this_epoch = 0
332
-
333
- if self.cfg.fast_forward_batches:
334
- log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
335
- # Technically we don't "see" these batches that we fast-forward through, but we use
336
- # this variable to update the position of the dataset so we need to include them here.
337
- self.global_train_examples_seen_this_epoch += (
338
- self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
339
- )
340
- # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
341
- # that variable is meant to track the actual number of tokens trained on.
342
-
343
- if self.global_train_examples_seen_this_epoch > 0:
344
- assert isinstance(self.dataset, IterableDataset)
345
- log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
346
- self.dataset.start_index = self.global_train_examples_seen_this_epoch
347
-
348
- # Reset learning rate and weight decay to the values from the config, not the checkpoint.
349
- log.info("Resetting learning rate...")
350
- new_learning_rate = self.scheduler.get_lr(
351
- self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
352
- )
353
- for group in self.optim.param_groups:
354
- group["lr"] = new_learning_rate
355
- group["initial_lr"] = self.cfg.optimizer.learning_rate
356
- if "weight_decay" in group and group["weight_decay"] > 0.0:
357
- group["weight_decay"] = self.cfg.optimizer.weight_decay
358
-
359
- # RNG states.
360
- if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
361
- log.info("Restoring RNG states...")
362
- rng_state = state_dict["rng"]
363
- self.restore_rng_state(rng_state)
364
- else:
365
- log.warning(
366
- "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
367
- "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
368
- "with a different world size. If that's the case you can safely ignore this warning."
369
- )
370
-
371
- def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
372
- random.setstate(rng_state["python"])
373
- np.random.set_state(rng_state["numpy"])
374
- torch.set_rng_state(rng_state["torch"])
375
- torch.cuda.set_rng_state(rng_state["cuda"])
376
-
377
- def _save_checkpoint(
378
- self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
379
- ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
380
- if checkpoint_type == CheckpointType.sharded:
381
- suffix = ""
382
- current_checkpoints = self.checkpoints
383
- link_latest = get_fs_local_rank() == 0
384
- num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
385
- elif checkpoint_type == CheckpointType.unsharded:
386
- suffix = "-unsharded"
387
- current_checkpoints = self.unsharded_checkpoints
388
- link_latest = get_global_rank() == 0
389
- num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
390
- elif checkpoint_type == CheckpointType.sharded_ephemeral:
391
- suffix = ""
392
- current_checkpoints = self.ephemeral_checkpoints
393
- link_latest = get_fs_local_rank() == 0
394
- num_checkpoints_to_keep = 1
395
- else:
396
- raise NotImplementedError(checkpoint_type)
397
-
398
- # Zero-gradients to avoid gathering them.
399
- self.optim.zero_grad(set_to_none=True)
400
-
401
- # Flush data indices file.
402
- # TODO: upload the indices files?
403
- if self.indices_file is not None:
404
- self.indices_file.flush()
405
-
406
- checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
407
- remote_checkpoint_dir: Optional[str] = None
408
- if self.cfg.remote_save_folder is not None:
409
- remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
410
- current_checkpoints.append(checkpoint_dir)
411
-
412
- # Save the checkpoint.
413
- try:
414
- checkpointer.save_checkpoint(
415
- checkpoint_dir,
416
- self.fsdp_model,
417
- self.optim,
418
- self.trainer_state_dict(),
419
- upload_to=remote_checkpoint_dir,
420
- )
421
- except FileExistsError:
422
- raise OLMoConfigurationError(
423
- f"Checkpoint for step {self.global_step} already exists, use --save-overwrite to overwrite it"
424
- )
425
-
426
- if link_latest:
427
- if get_global_rank() == 0:
428
- # Link to 'latest'.
429
- latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
430
- latest_path.unlink(missing_ok=True)
431
- try:
432
- latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
433
- except FileExistsError:
434
- # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
435
- # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
436
- # file-systems.
437
- if latest_path.resolve().name != checkpoint_dir.name:
438
- raise
439
-
440
- # Remove old checkpoints.
441
- if num_checkpoints_to_keep > 0:
442
- while len(current_checkpoints) > num_checkpoints_to_keep:
443
- self.remove_checkpoint(0, checkpoint_type)
444
-
445
- barrier()
446
-
447
- if remote_checkpoint_dir is not None:
448
- return remote_checkpoint_dir, checkpoint_dir
449
- else:
450
- return checkpoint_dir, None
451
-
452
- def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
453
- checkpointer = build_sharded_checkpointer(self.cfg)
454
- result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
455
- self.last_sharded_checkpoint_step = self.global_step
456
- return result
457
-
458
- def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
459
- checkpointer = build_sharded_checkpointer(self.cfg)
460
- result = self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)
461
- self.last_sharded_checkpoint_step = self.global_step
462
- return result
463
-
464
- def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
465
- oldest_checkpoint = checkpoints.pop(idx)
466
- barrier()
467
- if get_global_rank() == 0 and oldest_checkpoint.is_dir():
468
- shutil.rmtree(oldest_checkpoint, ignore_errors=True)
469
- latest_path = Path(self.cfg.save_folder) / "latest"
470
- if latest_path.resolve() == oldest_checkpoint.resolve():
471
- latest_path.unlink()
472
- barrier()
473
-
474
- def remove_sharded_checkpoint(self, idx: int = 0):
475
- self._remove_sharded_checkpoint(idx, self.checkpoints)
476
-
477
- def remove_ephemeral_checkpoint(self, idx: int = 0):
478
- self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)
479
-
480
- def restore_sharded_checkpoint(
481
- self,
482
- load_path: PathOrStr,
483
- local_cache: Optional[PathOrStr] = None,
484
- *,
485
- load_optimizer_state: bool = True,
486
- load_trainer_state: bool = True,
487
- sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
488
- ):
489
- # Zero-gradients to avoid gathering them.
490
- self.optim.zero_grad(set_to_none=True)
491
- checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
492
- trainer_state = checkpointer.restore_checkpoint(
493
- load_path,
494
- self.fsdp_model,
495
- self.optim,
496
- local_cache=local_cache,
497
- load_optimizer_state=load_optimizer_state,
498
- )
499
- if load_trainer_state:
500
- self.load_trainer_state_dict(trainer_state)
501
- barrier()
502
-
503
- def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
504
- checkpointer = FullCheckpointer(self.cfg)
505
- result = self._save_checkpoint(checkpointer, CheckpointType.unsharded)
506
- self.last_unsharded_checkpoint_step = self.global_step
507
- return result
508
-
509
- def remove_unsharded_checkpoint(self, idx: int = 0):
510
- barrier()
511
- oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
512
- if get_global_rank() == 0 and oldest_checkpoint.is_dir():
513
- shutil.rmtree(oldest_checkpoint, ignore_errors=True)
514
- latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
515
- if latest_path.resolve() == oldest_checkpoint.resolve():
516
- latest_path.unlink()
517
- barrier()
518
-
519
- def restore_unsharded_checkpoint(
520
- self,
521
- load_path: PathOrStr,
522
- local_cache: Optional[PathOrStr] = None,
523
- *,
524
- load_optimizer_state: bool = True,
525
- load_trainer_state: bool = True,
526
- ):
527
- # Zero-gradients to avoid gathering them.
528
- self.optim.zero_grad(set_to_none=True)
529
- checkpointer = FullCheckpointer(self.cfg)
530
- trainer_state = checkpointer.restore_checkpoint(
531
- load_path,
532
- self.fsdp_model,
533
- self.optim,
534
- local_cache=local_cache,
535
- load_optimizer_state=load_optimizer_state,
536
- )
537
- if load_trainer_state:
538
- self.load_trainer_state_dict(trainer_state)
539
- barrier()
540
-
541
- def save_checkpoint(
542
- self, checkpoint_type: CheckpointType = CheckpointType.sharded
543
- ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
544
- result: Tuple[PathOrStr, Optional[PathOrStr]]
545
- if checkpoint_type == CheckpointType.sharded:
546
- result = self.save_sharded_checkpoint()
547
- elif checkpoint_type == CheckpointType.unsharded:
548
- result = self.save_unsharded_checkpoint()
549
- elif checkpoint_type == CheckpointType.sharded_ephemeral:
550
- result = self.save_ephemeral_checkpoint()
551
- else:
552
- raise NotImplementedError(checkpoint_type)
553
-
554
- gc_cuda()
555
- return result
556
-
557
- def restore_checkpoint(
558
- self,
559
- load_path: PathOrStr,
560
- *,
561
- checkpoint_type: Optional[CheckpointType] = None,
562
- local_cache: Optional[PathOrStr] = None,
563
- load_optimizer_state: bool = True,
564
- load_trainer_state: bool = True,
565
- sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
566
- ):
567
- if checkpoint_type == CheckpointType.unsharded or (
568
- checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
569
- ):
570
- self.restore_unsharded_checkpoint(
571
- load_path,
572
- local_cache=local_cache,
573
- load_optimizer_state=load_optimizer_state,
574
- load_trainer_state=load_trainer_state,
575
- )
576
- elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
577
- self.restore_sharded_checkpoint(
578
- load_path,
579
- local_cache=local_cache,
580
- load_optimizer_state=load_optimizer_state,
581
- load_trainer_state=load_trainer_state,
582
- sharded_checkpointer=sharded_checkpointer,
583
- )
584
- elif checkpoint_type is not None:
585
- raise NotImplementedError(checkpoint_type)
586
-
587
- gc_cuda()
588
-
589
- def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
590
- if checkpoint_type == CheckpointType.sharded:
591
- self.remove_sharded_checkpoint(idx=idx)
592
- elif checkpoint_type == CheckpointType.unsharded:
593
- self.remove_unsharded_checkpoint(idx=idx)
594
- elif checkpoint_type == CheckpointType.sharded_ephemeral:
595
- self.remove_ephemeral_checkpoint(idx=idx)
596
- else:
597
- raise NotImplementedError(checkpoint_type)
598
-
599
- def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
600
- # Labels are just input IDs shifted to the left (first item is ignored).
601
- labels, label_mask, attention_mask = (
602
- batch["input_ids"].clone(),
603
- batch.get("label_mask"),
604
- batch.get("attention_mask"),
605
- )
606
- if label_mask is not None:
607
- labels.masked_fill_(~label_mask, -100)
608
- if attention_mask is not None:
609
- labels.masked_fill_(attention_mask == 0.0, -100)
610
- return labels[..., 1:].contiguous()
611
-
612
- def model_forward(
613
- self, batch: Dict[str, Any], loss_reduction: str = "mean", compute_z_loss: bool = False
614
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
615
- # shape: (batch_size, seq_len, vocab_size)
616
- logits = self.fsdp_model(
617
- input_ids=batch["input_ids"],
618
- attention_mask=batch.get("attention_mask"),
619
- attention_bias=batch.get("attention_bias"),
620
- ).logits
621
- logits_for_loss = logits[..., :-1, :].contiguous()
622
- # shape: (batch_size * seq_len, vocab_size)
623
- logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
624
- # shape: (batch_size, seq_len)
625
- labels = self.get_labels(batch)
626
- # shape: (batch_size * seq_len,)
627
- labels = labels.view(-1)
628
- ce_loss, z_loss = self.loss_fn(
629
- logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction, compute_z_loss=compute_z_loss
630
- )
631
- if loss_reduction == "none":
632
- # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
633
- ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
634
- if z_loss is not None:
635
- z_loss = z_loss.view(batch["input_ids"].shape[0], -1)
636
- return ce_loss, z_loss, logits
637
-
638
- def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
639
- # Split into micro-batches.
640
- # print(f"Start preparing micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
641
- micro_batches = self.split_batch(batch)
642
-
643
- # In case this helps with memory utilization.
644
- del batch
645
-
646
- ce_batch_loss = torch.tensor(0.0, device=self.device)
647
- z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
648
- # print(f"Start training micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
649
- for micro_batch in micro_batches:
650
- with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
651
- # Run forward pass.
652
- # print(f"Start forward pass at step {self.global_step}") if get_global_rank() == 0 else None
653
- # print(f"micro_batch['input_ids'].shape: {micro_batch['input_ids'].shape}") if get_global_rank() == 0 else None
654
- # print(f"min micro_batch['input_ids']: {micro_batch['input_ids'].min()}") if get_global_rank() == 0 else None
655
- # print(f"max micro_batch['input_ids']: {micro_batch['input_ids'].max()}") if get_global_rank() == 0 else None
656
- if get_fs_local_rank() == 0 and (self.global_step == 1421 or self.global_step == 1422 or self.global_step == 1423):
657
- # save micro batch input_ids to file, which is a list of integers
658
- with open(f"micro_batch_step{self.global_step}.txt", "w") as f:
659
- for i in range(micro_batch["input_ids"].shape[0]):
660
- f.write(f"{micro_batch['input_ids'][i].tolist()}\n")
661
- ce_loss, z_loss, logits = self.model_forward(
662
- micro_batch, compute_z_loss=self.cfg.softmax_auxiliary_loss
663
- )
664
- # print(f"End micro_batch at step {self.global_step} with ce_loss {ce_loss}, z_loss {z_loss}") if get_global_rank() == 0 else None
665
- ce_loss = ce_loss / len(micro_batches)
666
-
667
- # In case this helps with memory utilization.
668
- del micro_batch
669
-
670
- # Update overall CE batch loss.
671
- ce_batch_loss += ce_loss.detach()
672
-
673
- # Get loss to optimize for.
674
- if self.cfg.softmax_auxiliary_loss:
675
- assert z_loss is not None
676
- assert z_batch_loss is not None
677
- z_loss = z_loss / len(micro_batches)
678
- loss = ce_loss + z_loss
679
-
680
- # Update overall Z batch loss.
681
- z_batch_loss += z_loss.detach()
682
- else:
683
- loss = ce_loss
684
-
685
- del logits
686
- # print(f"---before micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
687
- # print(f" loss value: {loss}") if get_global_rank() == 0 else None
688
- # Run backward pass.
689
- loss.backward()
690
- # print(f"---after micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
691
- return ce_batch_loss, z_batch_loss
692
-
693
- def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
694
- metrics: Dict[str, float] = {}
695
-
696
- # Write data-indices to file.
697
- if self.indices_file is not None and "index" in batch:
698
- indices = "\t".join(str(int(i)) for i in batch["index"])
699
- self.indices_file.write(f"{self.global_step}\t{indices}\n")
700
-
701
- # Zero-gradients.
702
- self.optim.zero_grad(set_to_none=True)
703
-
704
- # Move tensors to the right device.
705
- batch = move_to_device(batch, self.device)
706
-
707
- # Run forward-backward pass.
708
- ce_batch_loss, z_batch_loss = self.train_batch(batch)
709
- # Collect loss, potentially reducing over all ranks.
710
- if reduce_global_loss:
711
- dist.reduce(ce_batch_loss, 0)
712
- ce_batch_loss.div_(get_world_size())
713
- if z_batch_loss is not None:
714
- dist.reduce(z_batch_loss, 0)
715
- z_batch_loss.div_(get_world_size())
716
-
717
- # Clip gradient norms and collect param/gradient/optim metrics.
718
- should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
719
- optim_metrics = self.optim.clip_grads_and_collect_metrics(
720
- self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
721
- )
722
-
723
- # Adjust the learning rate.
724
- for group in self.optim.param_groups:
725
- # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
726
- # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
727
- # the corresponding values from `self.cfg`.
728
- group["lr"] = self.scheduler.get_lr(
729
- self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
730
- )
731
- group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
732
- self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
733
- )
734
- group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
735
- self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
736
- )
737
-
738
- # Optimizer step.
739
- self.optim.step()
740
-
741
- # Collect metrics and check for NaN loss.
742
- # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
743
- if torch.isnan(ce_batch_loss):
744
- raise ValueError("nan loss encountered")
745
- if z_batch_loss is not None and torch.isnan(z_batch_loss):
746
- raise ValueError("nan loss encountered")
747
- for key, value in optim_metrics.items():
748
- metrics[f"optim/{key}"] = value.item()
749
- self.cur_train_loss = ce_batch_loss.item()
750
- self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
751
- metrics["train/CrossEntropyLoss"] = self.cur_train_loss
752
- metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
753
- if z_batch_loss is not None:
754
- metrics["train/ZLoss"] = z_batch_loss.item()
755
-
756
- # Maybe collect post-step optimizer-specific metrics.
757
- if should_log_optim_metrics_this_step:
758
- optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
759
- for key, value in optim_metrics.items():
760
- metrics[f"optim/{key}"] = value.item()
761
-
762
- return metrics
763
-
764
- def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
765
- with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
766
- ce_loss, _, logits = self.model_forward(batch, loss_reduction="none")
767
- return ce_loss.mean(dim=-1), logits
768
-
769
- def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
770
- # Move tensors to the right device.
771
- batch = move_to_device(batch, self.device)
772
-
773
- # Run forward pass.
774
- with torch.no_grad(): # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
775
- ce_loss, logits = self.eval_batch(batch)
776
-
777
- # Update metrics.
778
- evaluator.update_metrics(
779
- batch, ce_loss, logits
780
- ) # batch includes all keys that the downstream evaluation needs
781
-
782
- barrier()
783
-
784
- def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
785
- microbatch_size = self.cfg.device_train_microbatch_size
786
- batch_size = batch["input_ids"].shape[0]
787
- if batch_size <= microbatch_size:
788
- return [batch]
789
- else:
790
- micro_batches = {}
791
- for key, value in batch.items():
792
- if isinstance(value, torch.Tensor):
793
- micro_batches[key] = value.split(microbatch_size, dim=0)
794
- elif isinstance(value, list):
795
- micro_batches[key] = [
796
- value[microbatch_size * i : microbatch_size * i + microbatch_size]
797
- for i in range(math.ceil(batch_size / microbatch_size))
798
- ]
799
- else:
800
- raise ValueError(f"unexpected item in batch: '{key}={value}'")
801
- return [
802
- {key: value[i] for key, value in micro_batches.items()} # type: ignore
803
- for i in range(len(micro_batches["input_ids"]))
804
- ]
805
-
806
- def system_metrics(self) -> Dict[str, float]:
807
- metrics = {}
808
- if self.global_step < 3 or self.global_step % 10 == 0:
809
- peak_gpu_mb = peak_gpu_memory()
810
- if peak_gpu_mb is not None:
811
- metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
812
- return metrics
813
-
814
- def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
815
- def format_float(value: float) -> str:
816
- if value < 0.0001:
817
- return str(value) # scientific notation
818
- elif value > 1000:
819
- return f"{int(value):,d}"
820
- elif value > 100:
821
- return f"{value:.1f}"
822
- elif value > 10:
823
- return f"{value:.2f}"
824
- elif value > 1:
825
- return f"{value:.3f}"
826
- else:
827
- return f"{value:.4f}"
828
-
829
- log.info(
830
- f"{prefix}\n"
831
- + "\n".join(
832
- [
833
- f" {name}={format_float(value)}"
834
- for name, value in metrics.items()
835
- if not name.startswith("optim/") # there are too many optimizer metrics
836
- ]
837
- )
838
- )
839
-
840
- def should_log_optim_metrics_this_step(self) -> bool:
841
- if self.cfg.wandb is None:
842
- # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
843
- # to log to the console.
844
- return False
845
- optim_log_interval = self.cfg.optimizer.metrics_log_interval
846
- if optim_log_interval is None:
847
- optim_log_interval = self.cfg.wandb.log_interval
848
- else:
849
- optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
850
- return self.global_step % optim_log_interval == 0
851
-
852
- def should_log_this_step(self) -> bool:
853
- if self.global_step % self.cfg.console_log_interval == 0:
854
- return True
855
- elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
856
- return True
857
- else:
858
- return False
859
-
860
- def eval(self) -> Dict[str, Any]:
861
- # Zero gradients and set model to 'eval' mode.
862
- self.optim.zero_grad(set_to_none=True)
863
- self.fsdp_model.eval()
864
-
865
- eval_metrics = {}
866
- for evaluator in self.evaluators:
867
- log.info(f"Running evaluation for '{evaluator.label}'...")
868
-
869
- # Reset metrics.
870
- evaluator.reset_metrics()
871
-
872
- # Initialize data loader iterator.
873
- eval_batches = iter(evaluator.eval_loader)
874
-
875
- # Adjust how many batches to evaluate on.
876
- num_eval_batches = (
877
- evaluator.subset_num_batches
878
- if evaluator.subset_num_batches is not None
879
- else self.cfg.eval_subset_num_batches
880
- )
881
- if num_eval_batches > 0:
882
- num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
883
- eval_batches = islice(eval_batches, num_eval_batches)
884
-
885
- # Run model over batches.
886
- for eval_step, eval_batch in enumerate(eval_batches):
887
- self.eval_step(eval_batch, evaluator)
888
-
889
- # Log to console.
890
- if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
891
- log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")
892
-
893
- # Get final metrics.
894
- metrics = evaluator.compute_metrics()
895
- eval_metrics.update(metrics)
896
- self.log_metrics_to_console(f"{evaluator.label}", metrics)
897
-
898
- del eval_batches
899
-
900
- return eval_metrics
901
-
902
- def check_if_cancelled(self) -> Tuple[bool, int]:
903
- should_cancel = False
904
- cancel_reason: Optional[str] = None
905
- extra_steps = 0
906
- if get_global_rank() == 0:
907
- if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
908
- # First check if we've reached the training time limit.
909
- should_cancel = True
910
- cancel_reason = "time limit reached"
911
- extra_steps = self.cfg.extra_steps_after_cancel
912
- elif (
913
- self.cfg.early_stopping_factor is not None
914
- and self.global_step > self.cfg.scheduler.t_warmup
915
- and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
916
- ):
917
- # Next check if early stopping loss criteria is met.
918
- should_cancel = True
919
- cancel_reason = "early stopping from loss increase"
920
- elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
921
- # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.
922
- # We won't see it in the run object. So we have to use the import/export API to check.
923
- from requests.exceptions import RequestException
924
-
925
- try:
926
- api = wandb.Api(api_key=api_key)
927
- run = api.run(wandb.run.path)
928
- for tag in run.tags or []:
929
- if tag.lower() in {"cancel", "canceled", "cancelled"}:
930
- should_cancel = True
931
- cancel_reason = "Weights & Biases tag"
932
- extra_steps = self.cfg.extra_steps_after_cancel
933
- break
934
- except RequestException:
935
- pass
936
-
937
- run_canceled = synchronize_flag(should_cancel, self.device)
938
- if run_canceled:
939
- extra_steps = synchronize_value(extra_steps, self.device)
940
- if cancel_reason is None:
941
- if extra_steps > 0:
942
- log.warning(f"Run canceled, stopping in {extra_steps} more steps...")
943
- else:
944
- log.warning("Run canceled")
945
- else:
946
- if extra_steps > 0:
947
- log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
948
- else:
949
- log.warning(f"Run canceled due to {cancel_reason}")
950
-
951
- return run_canceled, extra_steps
952
-
953
- def fit(self):
954
- if self.cfg.stop_after is not None:
955
- if self.cfg.stop_at is None:
956
- self.cfg.stop_at = self.global_step + self.cfg.stop_after
957
- else:
958
- self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)
959
-
960
- self._start_time = time.time()
961
- self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.
962
-
963
- # Disable automatic garbage collection, FSDP doesn't work well with it.
964
- if self.cfg.gen1_gc_interval is not None:
965
- gc.disable()
966
-
967
- if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
968
- eval_metrics = self.eval()
969
- if wandb.run is not None:
970
- wandb.log(eval_metrics, step=self.global_step)
971
-
972
- # Set model to 'train' mode.
973
- self.fsdp_model.train()
974
-
975
- # Initialize monitors.
976
- assert self.cfg.device_train_batch_size is not None
977
- speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
978
- lr_monitor = LRMonitor(self.optim)
979
-
980
- # Log system metrics at the start of training.
981
- sys_metrics = self.system_metrics()
982
- if sys_metrics:
983
- self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
984
- if wandb.run is not None:
985
- wandb.log(sys_metrics, step=0)
986
-
987
- # Python Profiler stuff
988
- if self.cfg.python_profiling:
989
- python_profiler = cProfile.Profile()
990
- else:
991
- python_profiler = None
992
-
993
- # PyTorch Profiler stuff
994
- if self.cfg.torch_profiling and get_global_rank() == 0:
995
- from torch.profiler import schedule
996
-
997
- profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
998
-
999
- def on_trace_ready(p):
1000
- profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
1001
- profiler_output_dir.mkdir(exist_ok=True)
1002
-
1003
- output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
1004
- log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
1005
- output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
1006
- log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
1007
-
1008
- p.export_chrome_trace(
1009
- str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
1010
- )
1011
- if self.cfg.remote_save_folder is not None:
1012
- upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
1013
- log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
1014
- upload(trace_path, f"{upload_folder}/{trace_path.name}")
1015
-
1016
- from torch.profiler import ProfilerActivity
1017
-
1018
- torch_profiler = torch.profiler.profile(
1019
- activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1020
- record_shapes=False,
1021
- profile_memory=False,
1022
- with_stack=True,
1023
- schedule=profiling_schedule,
1024
- on_trace_ready=on_trace_ready,
1025
- )
1026
- del profiling_schedule
1027
- else:
1028
- import contextlib
1029
-
1030
- torch_profiler = contextlib.nullcontext()
1031
-
1032
- # Train.
1033
- first_batch: bool = True
1034
- cancel_initiated: bool = False
1035
- stop_at: Optional[int] = self.cfg.stop_at
1036
- save_checkpoints: bool = True
1037
-
1038
- with torch_profiler as p:
1039
- for epoch in range(self.epoch or 0, self.max_epochs):
1040
- for batch in self.train_loader:
1041
- # print(f" >>>>>>>>>>fit start with Global step: {self.global_step} <<<<<<<<<<<<<<<") if get_global_rank()==0 else None
1042
- # Bookkeeping.
1043
- # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
1044
- # batches see the same number of tokens, which should be the case for language model pre-training
1045
- # (at least when drop_last=True).
1046
- # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
1047
- # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
1048
- # fail loudly.
1049
- batch_size, seq_len = batch["input_ids"].shape
1050
- assert seq_len == self.cfg.model.max_sequence_length
1051
- assert batch_size == self.cfg.device_train_batch_size
1052
- global_batch_size = batch_size * get_world_size() # assumes batch size equal across ranks
1053
- self.global_step += 1
1054
- self.global_train_examples_seen_this_epoch += global_batch_size
1055
- self.global_train_tokens_seen += global_batch_size * seq_len
1056
- speed_monitor.batch_start(
1057
- self.global_train_tokens_seen,
1058
- batch_size * seq_len, # num tokens in batch for this device
1059
- # We start monitoring speed after the first batch since the first
1060
- # batch might be an outlier due to compiling and other initialization overhead.
1061
- record=not first_batch,
1062
- )
1063
-
1064
- should_log_this_step = self.should_log_this_step()
1065
-
1066
- # Run train step on batch.
1067
- metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)
1068
- # print(f" After train step with Global step: {self.global_step}") if get_global_rank()==0 else None
1069
-
1070
- # Maybe collect other metrics.
1071
- if should_log_this_step:
1072
- # Speed metrics.
1073
- metrics.update(speed_monitor.check())
1074
- # System metrics.
1075
- metrics.update(self.system_metrics())
1076
- # Learning rate metrics.
1077
- metrics.update(lr_monitor.check())
1078
-
1079
- # Log metrics to console.
1080
- if self.global_step % self.cfg.console_log_interval == 0:
1081
- if get_global_rank() == 0:
1082
- self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)
1083
- else:
1084
- log.info(f"[step={self.global_step}/{self.max_steps}]")
1085
-
1086
- # Log metrics to W&B.
1087
- if (
1088
- wandb.run is not None
1089
- and self.cfg.wandb is not None
1090
- and self.global_step % self.cfg.wandb.log_interval == 0
1091
- ):
1092
- wandb.log(metrics, step=self.global_step)
1093
-
1094
- # Check if/when run should be canceled.
1095
- if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
1096
- cancel_initiated, extra_steps = self.check_if_cancelled()
1097
- if cancel_initiated:
1098
- stop_at = (
1099
- self.global_step + extra_steps
1100
- if stop_at is None
1101
- else min(self.global_step + extra_steps, stop_at)
1102
- )
1103
-
1104
- # Maybe save sharded checkpoint.
1105
- if save_checkpoints and (
1106
- cancel_initiated
1107
- or (
1108
- self.global_step % self.cfg.save_interval == 0
1109
- and self.cfg.save_num_checkpoints_to_keep != 0
1110
- )
1111
- ):
1112
- log.info("Saving checkpoint...")
1113
- checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1114
- log.info(f"Checkpoint saved to {checkpoint_path}")
1115
-
1116
- # Remove any ephemeral checkpoints.
1117
- while self.ephemeral_checkpoints:
1118
- self.remove_ephemeral_checkpoint()
1119
-
1120
- # Reset speed monitor so that we don't count the time taken to save checkpoints.
1121
- speed_monitor.reset()
1122
-
1123
- # If the run was just canceled this will be the final checkpoint.
1124
- if cancel_initiated:
1125
- save_checkpoints = False
1126
- elif (
1127
- self.cfg.save_interval_ephemeral is not None
1128
- and self.global_step % self.cfg.save_interval_ephemeral == 0
1129
- ):
1130
- log.info("Saving ephemeral checkpoint...")
1131
- checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
1132
- log.info(f"Checkpoint saved to {checkpoint_path}")
1133
-
1134
- # Reset speed monitor so that we don't count the time taken to save checkpoints.
1135
- speed_monitor.reset()
1136
-
1137
- # Maybe save unsharded checkpoint.
1138
- if (
1139
- save_checkpoints
1140
- and self.cfg.save_interval_unsharded is not None
1141
- and self.global_step % self.cfg.save_interval_unsharded == 0
1142
- and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
1143
- ):
1144
- log.info("Saving unsharded checkpoint...")
1145
- checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1146
- log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1147
-
1148
- # Reset speed monitor so that we don't count the time taken to save checkpoints.
1149
- speed_monitor.reset()
1150
-
1151
- # Maybe run evaluations.
1152
- if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
1153
- eval_metrics = self.eval()
1154
-
1155
- # Log metrics to W&B.
1156
- if wandb.run is not None:
1157
- wandb.log(eval_metrics, step=self.global_step)
1158
-
1159
- # Reset speed monitor so that we don't count the time taken to run evaluations.
1160
- speed_monitor.reset()
1161
-
1162
- # Reset model to 'train' mode.
1163
- self.fsdp_model.train()
1164
-
1165
- # End of batch.
1166
- first_batch = False
1167
- if p is not None:
1168
- p.step()
1169
-
1170
- if stop_at is not None and self.global_step >= stop_at:
1171
- break
1172
-
1173
- # Run generation 1 garbage collection.
1174
- if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
1175
- gc.collect(1)
1176
-
1177
- # Python Profiler stuff
1178
- # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
1179
- if python_profiler is not None:
1180
- if self.global_step == 5:
1181
- python_profiler.enable()
1182
- elif self.global_step == 8:
1183
- python_profiler.disable()
1184
- python_profiler.print_stats(sort=SortKey.CUMULATIVE)
1185
- python_profiler = None
1186
- else:
1187
- log.info("Training epoch complete")
1188
- self.epoch = epoch + 1
1189
- self.global_train_examples_seen_this_epoch = 0
1190
- if self.epoch < self.max_epochs:
1191
- self.dataset.reshuffle()
1192
- continue
1193
-
1194
- break
1195
-
1196
- # Save final checkpoint.
1197
- if save_checkpoints:
1198
- if (
1199
- self.cfg.save_interval_unsharded is not None
1200
- and self.last_unsharded_checkpoint_step != self.global_step
1201
- ):
1202
- log.info("Saving final unsharded model checkpoint...")
1203
- checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
1204
- log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
1205
- elif (
1206
- self.cfg.save_num_checkpoints_to_keep != 0
1207
- and self.last_sharded_checkpoint_step != self.global_step
1208
- ):
1209
- log.info("Saving final checkpoint...")
1210
- checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
1211
- log.info(f"Checkpoint saved to {checkpoint_path}")
1212
-
1213
- def close(self, exit_code: int = 0) -> None:
1214
- gc_cuda()
1215
-
1216
- if self.indices_file is not None:
1217
- self.indices_file.flush()
1218
- self.indices_file.close()
1219
- if self._gc_init_state:
1220
- gc.enable()
1221
- else:
1222
- gc.disable()
1223
- if wandb.run is not None:
1224
- wandb.finish(exit_code=exit_code, quiet=True)
1225
-
1226
- def __enter__(self) -> Trainer:
1227
- return self
1228
-
1229
- def __exit__(self, exc_type, exc_val, exc_tb) -> None:
1230
- del exc_val, exc_tb
1231
- self.close(0 if exc_type is None else 1)