File size: 54,019 Bytes
ded00f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 |
from __future__ import annotations
import cProfile
import gc
import logging
import math
import os
import random
import shutil
import time
from collections import deque
from dataclasses import dataclass, field
from itertools import islice
from pathlib import Path
from pstats import SortKey
from typing import Any, Callable, Deque, Dict, List, Optional, TextIO, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from .aliases import PathOrStr
from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
from .config import (
CheckpointType,
SchedulerUnits,
ShardedCheckpointerType,
SpeedMonitorConfig,
TrainConfig,
)
from .data import IterableDataset
from .eval import Evaluator
from .exceptions import OLMoConfigurationError
from .model import OLMo
from .optim import Optimizer, Scheduler
from .torch_util import (
barrier,
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size,
move_to_device,
peak_gpu_memory,
synchronize_flag,
synchronize_value,
)
from .util import upload
# Public API of this module (what ``from <module> import *`` exports).
__all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
# Module-level logger named after this module.
log = logging.getLogger(__name__)
@dataclass
class SpeedMonitor:
    """Track training throughput over a sliding window of recently started batches.

    ``start_times`` and ``device_interval_tokens`` are parallel deques: entry ``i`` of each
    describes the same recorded batch, and at most ``cfg.window_size`` batches are kept.
    """

    cfg: SpeedMonitorConfig
    # Monotonic-clock start times of the recorded batches, oldest first.
    start_times: Deque[float] = field(default_factory=lambda: deque([]))
    # Global number of tokens trained on, as last reported to ``batch_start``.
    global_total_tokens: int = 0
    # Per-device token counts of the recorded batches, parallel to ``start_times``.
    device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))

    def batch_start(self, global_total_tokens: int, device_batch_num_tokens: int, record: bool = True) -> None:
        """Note the start of a batch.

        :param global_total_tokens: Global total number of tokens seen so far; always stored.
        :param device_batch_num_tokens: Number of tokens in this batch on this device.
        :param record: When ``False`` only the global token count is updated; the batch is not
            added to the throughput window.
        """
        self.global_total_tokens = global_total_tokens
        if record:
            if len(self.start_times) >= self.cfg.window_size:
                # Evict the oldest batch from both deques so they stay aligned.
                self.start_times.popleft()
                self.device_interval_tokens.popleft()
            self.start_times.append(time.monotonic())
            self.device_interval_tokens.append(device_batch_num_tokens)

    def reset(self) -> None:
        """Clear the throughput window. The global token count is kept."""
        self.start_times.clear()
        self.device_interval_tokens.clear()

    def check(self) -> Dict[str, float]:
        """Return throughput metrics.

        ``throughput/total_tokens`` is always present. The per-device rates are reported only when
        at least one batch is recorded and the monotonic clock has advanced since the oldest
        recorded batch started. Two clock readings can be equal, which would otherwise divide
        by zero.
        """
        metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
        if self.start_times:
            interval_seconds = time.monotonic() - self.start_times[0]
            if interval_seconds > 0:
                interval_batches = len(self.start_times)
                interval_tokens = sum(self.device_interval_tokens)
                metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
                metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
        return metrics
@dataclass
class LRMonitor:
    """Report the current learning rate of every optimizer parameter group."""

    optim: torch.optim.Optimizer

    def check(self) -> Dict[str, float]:
        """Return ``{"optim/learning_rate_group<i>": lr}`` for each param group ``i``, in order."""
        metrics: Dict[str, float] = {}
        for group_idx, param_group in enumerate(self.optim.param_groups):
            metrics[f"optim/learning_rate_group{group_idx}"] = param_group["lr"]
        return metrics
def cross_entropy_loss(
    logits,
    labels,
    ignore_index: int = -100,
    reduction: str = "mean",
    compute_z_loss: bool = False,
    z_loss_multiplier: float = 1e-4,
):
    """Cross-entropy loss with an optional auxiliary z-loss.

    As called from ``Trainer.model_forward``, ``logits`` is flattened to ``(N, vocab_size)`` and
    ``labels`` to ``(N,)``.

    :param ignore_index: Label value excluded from both losses.
    :param reduction: ``"mean"``, ``"sum"`` or ``"none"``, forwarded to ``F.cross_entropy`` and
        applied in the same way to the z-loss.
    :param compute_z_loss: If ``False``, the z-loss is skipped and ``None`` is returned for it.
    :param z_loss_multiplier: Scale applied to ``logsumexp(logits) ** 2``. The default keeps the
        previous hard-coded value.
    :returns: ``(ce_loss, z_loss_or_None)``.
    """
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
    if not compute_z_loss:
        return loss, None
    # Mask ignored positions for every reduction, including "none", so the per-position z-loss
    # is 0 exactly where the per-position CE loss is 0.
    z_squared = logits.logsumexp(-1).pow(2) * (labels != ignore_index)
    if reduction == "mean":
        # NOTE: averaged over *all* positions (ignored ones contribute 0), unlike the CE mean,
        # which only counts non-ignored positions.
        z_squared = z_squared.mean()
    elif reduction == "sum":
        z_squared = z_squared.sum()
    return loss, z_loss_multiplier * z_squared
@dataclass
class Trainer:
    # Holds the model, optimizer, scheduler and data loader, and implements training steps,
    # evaluation steps and checkpoint save / restore / pruning.
    cfg: TrainConfig
    model: OLMo
    # FSDP-wrapped model; forward passes and checkpoint save/restore go through this wrapper.
    fsdp_model: FSDP
    optim: Optimizer
    scheduler: Scheduler
    train_loader: DataLoader
    device: torch.device
    evaluators: List[Evaluator]
    # Current epoch; None until set (load_trainer_state_dict fills it from a checkpoint if unset).
    epoch: Optional[int] = None
    global_step: int = 0
    global_train_examples_seen_this_epoch: int = 0
    """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
    the data loader position on restarts."""
    global_train_tokens_seen: int = 0
    """Tracks the global total number of tokens trained on."""
    # Directories of saved checkpoints of each kind, oldest first (pruning removes index 0).
    checkpoints: List[Path] = field(default_factory=list)
    unsharded_checkpoints: List[Path] = field(default_factory=list)
    ephemeral_checkpoints: List[Path] = field(default_factory=list)
    # Lowest and most recent training cross-entropy loss, updated by train_step.
    min_train_loss: float = float("inf")
    cur_train_loss: float = float("inf")
    # Optional file that train_step appends "<step>\t<indices...>" lines to; flushed before saving a checkpoint.
    indices_file: Optional[TextIO] = None
    # NOTE(review): internal bookkeeping whose use is not shown here; presumably the training start
    # time and the garbage-collector state to restore — confirm.
    _start_time: float = 0.0
    _gc_init_state: bool = True
    # Loss callable returning (ce_loss, z_loss_or_None); replaced by a fused kernel in __post_init__
    # when cfg.fused_loss is set.
    loss_fn: Callable[..., Tuple[torch.Tensor, Optional[torch.Tensor]]] = field(default_factory=lambda: cross_entropy_loss)  # type: ignore
    # Steps at which the latest sharded (including ephemeral) / unsharded checkpoints were saved.
    last_sharded_checkpoint_step: Optional[int] = None
    last_unsharded_checkpoint_step: Optional[int] = None
def __post_init__(self):
    # When cfg.fused_loss is set, replace self.loss_fn with a wrapper around flash-attn's Triton
    # cross-entropy kernel. The wrapper keeps the (loss, z_loss_or_None) convention of the
    # module-level cross_entropy_loss.
    if self.cfg.fused_loss:
        # This local import shadows the module-level cross_entropy_loss inside this method.
        from flash_attn.ops.triton.cross_entropy import (  # type: ignore
            cross_entropy_loss,
        )

        def fused_loss_fn(
            logits, labels, ignore_index: int = -100, reduction: str = "mean", compute_z_loss: bool = False
        ):
            # NOTE(review): lse_square_scale=0.0 is passed unconditionally, so the kernel's z-loss
            # output is presumably all zeros even when compute_z_loss is True (the non-fused path
            # uses 1e-4 * logsumexp^2) — confirm against the installed flash-attn version.
            # NOTE(review): newer flash-attn releases appear to name this keyword ``ignore_index``
            # rather than ``ignored_index`` — verify against the pinned version.
            loss, z_loss = cross_entropy_loss(
                logits,
                labels,
                label_smoothing=0.0,
                logit_scale=1.0,
                lse_square_scale=0.0,
                ignored_index=ignore_index,
                inplace_backward=False,
                process_group=None,
            )
            # Positions whose label is ignore_index are excluded from the "mean" denominator.
            mask = labels != ignore_index
            if reduction == "mean":
                loss = loss.sum() / mask.sum()
            elif reduction == "sum":
                loss = loss.sum()
            else:
                # "none": keep the kernel's unreduced (presumably per-position) losses.
                loss = loss
            if not compute_z_loss:
                return loss, None
            # Unlike the non-fused path (which takes .mean() over all positions), the z-loss
            # mean here divides by the number of non-ignored positions.
            if reduction == "mean":
                z_loss = z_loss.sum() / mask.sum()
            elif reduction == "sum":
                z_loss = z_loss.sum()
            else:
                z_loss = z_loss
            return loss, z_loss

        self.loss_fn = fused_loss_fn
@property
def dataset(self) -> IterableDataset:
    # The training data loader's dataset, narrowed (via assert) to IterableDataset.
    train_dataset = self.train_loader.dataset
    assert isinstance(train_dataset, IterableDataset)
    return train_dataset
@property
def tokens_per_batch(self) -> int:
    # Token count of one global batch: global batch size times the model's max sequence length.
    batch_size = self.cfg.global_train_batch_size
    seq_len = self.cfg.model.max_sequence_length
    return batch_size * seq_len
@property
def batches_per_epoch(self) -> int:
    # Number of full global batches in one pass over the dataset; a trailing partial batch is dropped.
    full_batches, _remainder = divmod(self.dataset.total_size, self.cfg.global_train_batch_size)
    return full_batches
@property
def max_epochs(self) -> int:
    # Epoch count parsed from a "<N>ep" max_duration; every other max_duration form means one epoch.
    duration = self.cfg.max_duration
    if not (isinstance(duration, str) and duration.endswith("ep")):
        return 1
    return int(duration[:-2].strip())
@property
def max_steps(self) -> int:
    """Total training length in optimizer steps, derived from ``cfg.max_duration``.

    Accepted forms: an int (steps), ``"<N>T"`` (tokens, converted relative to the tokens already
    seen), ``"<N>ep"`` (epochs) or a numeric string (steps).

    :raises TypeError: if ``max_duration`` is neither an int nor a str.
    """
    duration = self.cfg.max_duration
    if isinstance(duration, int):
        return duration
    if not isinstance(duration, str):
        raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
    if duration.endswith("T"):
        # Go through float first so scientific notation such as "2e12T" parses.
        token_budget = int(float(duration[:-1].strip()))
        tokens_left = max(token_budget - self.global_train_tokens_seen, 0)
        return self.global_step + tokens_left // self.tokens_per_batch
    if duration.endswith("ep"):
        return int(duration[:-2].strip()) * self.batches_per_epoch
    # Go through float first so scientific notation such as "1e5" parses.
    return int(float(duration))
@property
def max_tokens(self) -> int:
    """Total training length in tokens, derived from ``cfg.max_duration``.

    ``"<N>T"`` is returned directly and ``"<N>ep"`` is converted via batches per epoch. Step
    budgets (int or numeric str) are converted relative to the current progress: tokens seen so
    far plus the remaining steps times ``tokens_per_batch``.

    :raises TypeError: if ``max_duration`` is neither an int nor a str.
    """
    duration = self.cfg.max_duration
    if isinstance(duration, str):
        if duration.endswith("T"):
            # Go through float first so scientific notation parses.
            return int(float(duration[:-1].strip()))
        if duration.endswith("ep"):
            n_epochs = int(duration[:-2].strip())
            return n_epochs * self.batches_per_epoch * self.tokens_per_batch
        # Go through float first so scientific notation parses.
        step_budget = int(float(duration))
    elif isinstance(duration, int):
        step_budget = duration
    else:
        raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
    steps_left = max(step_budget - self.global_step, 0)
    return self.global_train_tokens_seen + steps_left * self.tokens_per_batch
@property
def scheduler_current(self) -> int:
    # Current scheduler position in the scheduler's configured units (steps or tokens).
    units = self.cfg.scheduler.units
    if units == SchedulerUnits.steps:
        return self.global_step
    if units == SchedulerUnits.tokens:
        return self.global_train_tokens_seen
    raise NotImplementedError(units)
@property
def scheduler_max(self) -> int:
    # Scheduler end position in the scheduler's configured units (steps or tokens).
    units = self.cfg.scheduler.units
    if units == SchedulerUnits.steps:
        return self.max_steps
    if units == SchedulerUnits.tokens:
        return self.max_tokens
    raise NotImplementedError(units)
def trainer_state_dict(self) -> Dict[str, Any]:
    """Snapshot of trainer progress, checkpoint bookkeeping and global RNG states.

    The result is what ``load_trainer_state_dict`` consumes when restoring.
    """
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "cuda": torch.cuda.get_rng_state(),
    }
    return {
        "epoch": self.epoch,
        "global_step": self.global_step,
        "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
        "global_train_tokens_seen": self.global_train_tokens_seen,
        "world_size": get_world_size(),
        "checkpoints": self.checkpoints,
        "unsharded_checkpoints": self.unsharded_checkpoints,
        "ephemeral_checkpoints": self.ephemeral_checkpoints,
        "rng": rng_states,
    }
def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore trainer state from a dict produced by ``trainer_state_dict``.

    Restores, in order:
    - checkpoint bookkeeping, keeping only directories that still exist directly under
      ``cfg.save_folder``;
    - step, token and data-position counters, with fallbacks for older checkpoint keys;
    - the dataset start index;
    - LR / weight decay, reset from the config;
    - RNG states, only when the checkpoint was saved with the same world size.
    """
    # Resolve the save folder once instead of once per checkpoint entry.
    save_folder = Path(self.cfg.save_folder).resolve()

    def _existing_in_save_folder(paths) -> List[Path]:
        # Entries are coerced to Path in case they were deserialized as plain strings.
        return [
            path
            for path in map(Path, paths)
            if path.is_dir() and path.resolve().parent == save_folder
        ]

    # Checkpoint paths.
    self.checkpoints = _existing_in_save_folder(state_dict["checkpoints"])
    self.unsharded_checkpoints = _existing_in_save_folder(state_dict["unsharded_checkpoints"])
    # Older checkpoints may predate ephemeral checkpoints, hence the default.
    self.ephemeral_checkpoints = _existing_in_save_folder(state_dict.get("ephemeral_checkpoints", []))

    # Dataset / dataloader position.
    checkpoint_epoch = state_dict.get("epoch", 0)
    self.global_step = state_dict["global_step"]
    self.global_train_examples_seen_this_epoch = state_dict.get(
        "global_train_examples_seen_this_epoch",
        state_dict.get(  # for backwards compatibility
            "global_train_examples_seen",
            state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
        ),
    )
    self.global_train_tokens_seen = state_dict.get(
        "global_train_tokens_seen",
        state_dict.get("global_data_step", self.global_step)  # for backwards compatibility
        * self.cfg.global_train_batch_size
        * self.cfg.model.max_sequence_length,
    )

    if not self.cfg.restore_dataloader:
        # Start the data from scratch (global_step is still taken from the checkpoint).
        self.epoch = 0
        self.global_train_tokens_seen = 0
        self.global_train_examples_seen_this_epoch = 0
    elif self.epoch is None:
        self.epoch = checkpoint_epoch
    elif checkpoint_epoch != self.epoch:
        # The epoch was set before loading and differs from the checkpoint's: begin that epoch fresh.
        log.info(f"Starting new epoch (epoch = {self.epoch})")
        self.global_train_examples_seen_this_epoch = 0

    if self.cfg.fast_forward_batches:
        log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
        # Technically we don't "see" these batches that we fast-forward through, but we use
        # this variable to update the position of the dataset so we need to include them here.
        self.global_train_examples_seen_this_epoch += (
            self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
        )
        # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
        # that variable is meant to track the actual number of tokens trained on.

    if self.global_train_examples_seen_this_epoch > 0:
        assert isinstance(self.dataset, IterableDataset)
        log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
        self.dataset.start_index = self.global_train_examples_seen_this_epoch

    # Reset learning rate and weight decay to the values from the config, not the checkpoint.
    log.info("Resetting learning rate...")
    new_learning_rate = self.scheduler.get_lr(
        self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
    )
    for group in self.optim.param_groups:
        group["lr"] = new_learning_rate
        group["initial_lr"] = self.cfg.optimizer.learning_rate
        # Only groups that use weight decay get the configured value; groups at 0.0 stay at 0.0.
        if "weight_decay" in group and group["weight_decay"] > 0.0:
            group["weight_decay"] = self.cfg.optimizer.weight_decay

    # RNG states are only meaningful for the same world size they were saved with.
    if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
        log.info("Restoring RNG states...")
        rng_state = state_dict["rng"]
        self.restore_rng_state(rng_state)
    else:
        log.warning(
            "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
            "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
            "with a different world size. If that's the case you can safely ignore this warning."
        )
def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
    # Restore the global Python, NumPy, torch CPU and CUDA RNGs, in that order, from a dict
    # shaped like trainer_state_dict()["rng"]. Keys are looked up lazily, one per setter.
    setters = (
        (random.setstate, "python"),
        (np.random.set_state, "numpy"),
        (torch.set_rng_state, "torch"),
        (torch.cuda.set_rng_state, "cuda"),
    )
    for apply_state, key in setters:
        apply_state(rng_state[key])
def _save_checkpoint(
    self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    """Save a checkpoint of ``checkpoint_type`` with ``checkpointer`` and prune old ones.

    The checkpoint is written to ``<save_folder>/step<global_step><suffix>``, where ``suffix`` is
    ``-unsharded`` for unsharded checkpoints and empty otherwise. When ``cfg.remote_save_folder``
    is set, the same directory name under that folder is passed to the checkpointer as
    ``upload_to``. A ``latest<suffix>`` symlink is then pointed at the new directory. Finally the
    oldest entries of the matching bookkeeping list are removed until at most the configured
    number remain; a value <= 0 disables pruning.

    :returns: ``(remote_dir, local_dir)`` when a remote folder is configured, otherwise
        ``(local_dir, None)``.
    :raises OLMoConfigurationError: if the checkpointer raises ``FileExistsError``.
    :raises NotImplementedError: for an unknown ``checkpoint_type``.
    """
    if checkpoint_type == CheckpointType.sharded:
        suffix = ""
        current_checkpoints = self.checkpoints
        link_latest = get_fs_local_rank() == 0
        num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
    elif checkpoint_type == CheckpointType.unsharded:
        suffix = "-unsharded"
        current_checkpoints = self.unsharded_checkpoints
        link_latest = get_global_rank() == 0
        num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
    elif checkpoint_type == CheckpointType.sharded_ephemeral:
        # Ephemeral checkpoints share the "" suffix (and so the 'latest' symlink) with sharded
        # ones; only the most recent ephemeral checkpoint is kept.
        suffix = ""
        current_checkpoints = self.ephemeral_checkpoints
        link_latest = get_fs_local_rank() == 0
        num_checkpoints_to_keep = 1
    else:
        raise NotImplementedError(checkpoint_type)

    # Zero-gradients to avoid gathering them.
    self.optim.zero_grad(set_to_none=True)

    # Flush data indices file.
    # TODO: upload the indices files?
    if self.indices_file is not None:
        self.indices_file.flush()

    checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
    remote_checkpoint_dir: Optional[str] = None
    if self.cfg.remote_save_folder is not None:
        remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
    # NOTE(review): the directory is recorded before the save; if the save raises, it stays in
    # the bookkeeping list.
    current_checkpoints.append(checkpoint_dir)

    # Save the checkpoint.
    try:
        checkpointer.save_checkpoint(
            checkpoint_dir,
            self.fsdp_model,
            self.optim,
            self.trainer_state_dict(),
            upload_to=remote_checkpoint_dir,
        )
    except FileExistsError:
        raise OLMoConfigurationError(
            f"Checkpoint for step {self.global_step} already exists, use --save-overwrite to overwrite it"
        )

    if link_latest:
        # NOTE(review): for sharded checkpoints link_latest is true on every file-system local
        # rank 0, but only global rank 0 enters this branch. Nodes with separate file systems
        # therefore never get a 'latest' link — confirm whether this inner check is intended.
        if get_global_rank() == 0:
            # Link to 'latest'.
            latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
            latest_path.unlink(missing_ok=True)
            try:
                # Relative target (just the directory name), so the link stays valid if the
                # save folder is moved.
                latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
            except FileExistsError:
                # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
                # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
                # file-systems.
                if latest_path.resolve().name != checkpoint_dir.name:
                    raise

    # Remove old checkpoints.
    if num_checkpoints_to_keep > 0:
        while len(current_checkpoints) > num_checkpoints_to_keep:
            self.remove_checkpoint(0, checkpoint_type)

    barrier()

    if remote_checkpoint_dir is not None:
        return remote_checkpoint_dir, checkpoint_dir
    else:
        return checkpoint_dir, None
def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    # Save a regular sharded checkpoint and record the step it was taken at.
    saved_paths = self._save_checkpoint(build_sharded_checkpointer(self.cfg), CheckpointType.sharded)
    self.last_sharded_checkpoint_step = self.global_step
    return saved_paths
def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    # Save an ephemeral checkpoint (same sharded checkpointer as regular ones). This also updates
    # last_sharded_checkpoint_step.
    saved_paths = self._save_checkpoint(
        build_sharded_checkpointer(self.cfg), CheckpointType.sharded_ephemeral
    )
    self.last_sharded_checkpoint_step = self.global_step
    return saved_paths
def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
    # Drop entry `idx` from the `checkpoints` bookkeeping list on every rank. Global rank 0 also
    # deletes the directory and unlinks 'latest' if it pointed there. Barriers on both sides keep
    # all ranks in step.
    removed_dir = checkpoints.pop(idx)
    barrier()
    if get_global_rank() == 0 and removed_dir.is_dir():
        shutil.rmtree(removed_dir, ignore_errors=True)
        latest_link = Path(self.cfg.save_folder) / "latest"
        if latest_link.resolve() == removed_dir.resolve():
            latest_link.unlink()
    barrier()
def remove_sharded_checkpoint(self, idx: int = 0):
    # Remove the `idx`-th (default: oldest) regular sharded checkpoint.
    self._remove_sharded_checkpoint(idx=idx, checkpoints=self.checkpoints)
def remove_ephemeral_checkpoint(self, idx: int = 0):
    # Remove the `idx`-th (default: oldest) ephemeral sharded checkpoint.
    self._remove_sharded_checkpoint(idx=idx, checkpoints=self.ephemeral_checkpoints)
def restore_sharded_checkpoint(
    self,
    load_path: PathOrStr,
    local_cache: Optional[PathOrStr] = None,
    *,
    load_optimizer_state: bool = True,
    load_trainer_state: bool = True,
    sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
):
    """Load a sharded checkpoint into the FSDP model (and optionally optimizer/trainer state).

    ``sharded_checkpointer`` selects the checkpointer implementation. The call ends with a
    barrier across ranks.
    """
    # Drop gradients first so they are not gathered during the restore.
    self.optim.zero_grad(set_to_none=True)
    restored_state = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer).restore_checkpoint(
        load_path,
        self.fsdp_model,
        self.optim,
        local_cache=local_cache,
        load_optimizer_state=load_optimizer_state,
    )
    if load_trainer_state:
        self.load_trainer_state_dict(restored_state)
    barrier()
def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    # Save a full (unsharded) checkpoint and record the step it was taken at.
    saved_paths = self._save_checkpoint(FullCheckpointer(self.cfg), CheckpointType.unsharded)
    self.last_unsharded_checkpoint_step = self.global_step
    return saved_paths
def remove_unsharded_checkpoint(self, idx: int = 0):
    # Drop unsharded checkpoint `idx` (default: oldest) from the bookkeeping list on every rank.
    # Global rank 0 also deletes the directory and unlinks 'latest-unsharded' if it pointed there.
    # The call is bracketed by barriers.
    barrier()
    stale_dir = self.unsharded_checkpoints.pop(idx)
    if get_global_rank() == 0 and stale_dir.is_dir():
        shutil.rmtree(stale_dir, ignore_errors=True)
        latest_link = Path(self.cfg.save_folder) / "latest-unsharded"
        if latest_link.resolve() == stale_dir.resolve():
            latest_link.unlink()
    barrier()
def restore_unsharded_checkpoint(
    self,
    load_path: PathOrStr,
    local_cache: Optional[PathOrStr] = None,
    *,
    load_optimizer_state: bool = True,
    load_trainer_state: bool = True,
):
    """Load a full (unsharded) checkpoint into the FSDP model (and optionally optimizer/trainer state).

    The call ends with a barrier across ranks.
    """
    # Drop gradients first so they are not gathered during the restore.
    self.optim.zero_grad(set_to_none=True)
    restored_state = FullCheckpointer(self.cfg).restore_checkpoint(
        load_path,
        self.fsdp_model,
        self.optim,
        local_cache=local_cache,
        load_optimizer_state=load_optimizer_state,
    )
    if load_trainer_state:
        self.load_trainer_state_dict(restored_state)
    barrier()
def save_checkpoint(
    self, checkpoint_type: CheckpointType = CheckpointType.sharded
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    """Save a checkpoint of the given type, then free cached CUDA memory.

    :returns: whatever the type-specific save method returns.
    :raises NotImplementedError: for an unknown ``checkpoint_type``.
    """
    savers = (
        (CheckpointType.sharded, self.save_sharded_checkpoint),
        (CheckpointType.unsharded, self.save_unsharded_checkpoint),
        (CheckpointType.sharded_ephemeral, self.save_ephemeral_checkpoint),
    )
    for kind, save_fn in savers:
        if checkpoint_type == kind:
            break
    else:
        raise NotImplementedError(checkpoint_type)
    saved_paths: Tuple[PathOrStr, Optional[PathOrStr]] = save_fn()
    gc_cuda()
    return saved_paths
def restore_checkpoint(
    self,
    load_path: PathOrStr,
    *,
    checkpoint_type: Optional[CheckpointType] = None,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
    load_trainer_state: bool = True,
    sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
):
    """Restore a checkpoint, inferring its type from the path when ``checkpoint_type`` is None.

    With ``checkpoint_type=None``, paths ending in ``-unsharded`` (ignoring a trailing ``/``) are
    restored as unsharded checkpoints and everything else as sharded. Ephemeral checkpoints are
    written with the same sharded checkpointer as regular ones (``save_ephemeral_checkpoint``), so
    ``CheckpointType.sharded_ephemeral`` is restored through the sharded path too. Cached CUDA
    memory is freed afterwards.

    :raises NotImplementedError: for any other ``checkpoint_type``.
    """
    if checkpoint_type == CheckpointType.unsharded or (
        checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
    ):
        self.restore_unsharded_checkpoint(
            load_path,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
            load_trainer_state=load_trainer_state,
        )
    elif (
        checkpoint_type is None
        or checkpoint_type == CheckpointType.sharded
        or checkpoint_type == CheckpointType.sharded_ephemeral
    ):
        self.restore_sharded_checkpoint(
            load_path,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
            load_trainer_state=load_trainer_state,
            sharded_checkpointer=sharded_checkpointer,
        )
    else:
        raise NotImplementedError(checkpoint_type)
    gc_cuda()
def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
    """Remove the ``idx``-th checkpoint of ``checkpoint_type`` from bookkeeping and disk.

    :raises NotImplementedError: for an unknown ``checkpoint_type``.
    """
    removers = (
        (CheckpointType.sharded, self.remove_sharded_checkpoint),
        (CheckpointType.unsharded, self.remove_unsharded_checkpoint),
        (CheckpointType.sharded_ephemeral, self.remove_ephemeral_checkpoint),
    )
    for kind, remove_fn in removers:
        if checkpoint_type == kind:
            remove_fn(idx=idx)
            return
    raise NotImplementedError(checkpoint_type)
def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
    """Build next-token labels from ``batch["input_ids"]`` without modifying the batch.

    Positions where ``label_mask`` is False or ``attention_mask`` is 0 are set to -100. The first
    position is then dropped so the labels line up with ``logits[..., :-1, :]``.
    """
    labels = batch["input_ids"].clone()
    label_mask = batch.get("label_mask")
    if label_mask is not None:
        labels.masked_fill_(~label_mask, -100)
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        labels.masked_fill_(attention_mask == 0.0, -100)
    # Shift left: the label at position i is the input token at position i + 1.
    return labels[..., 1:].contiguous()
def model_forward(
    self, batch: Dict[str, Any], loss_reduction: str = "mean", compute_z_loss: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Run the model on ``batch`` and compute the next-token loss with ``self.loss_fn``.

    :returns: ``(ce_loss, z_loss, logits)``. With ``loss_reduction="none"`` the losses are
        reshaped to one row per batch instance. ``z_loss`` is whatever ``self.loss_fn`` returns
        for it (``None`` with the loss functions in this module unless ``compute_z_loss``).
    """
    model_output = self.fsdp_model(
        input_ids=batch["input_ids"],
        attention_mask=batch.get("attention_mask"),
        attention_bias=batch.get("attention_bias"),
    )
    logits = model_output.logits
    vocab_size = logits.size(-1)
    # The last position has no next token; drop it and flatten to (N, vocab_size).
    flat_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    flat_labels = self.get_labels(batch).view(-1)
    ce_loss, z_loss = self.loss_fn(
        flat_logits, flat_labels, ignore_index=-100, reduction=loss_reduction, compute_z_loss=compute_z_loss
    )
    if loss_reduction == "none":
        num_rows = batch["input_ids"].shape[0]
        ce_loss = ce_loss.view(num_rows, -1)
        if z_loss is not None:
            z_loss = z_loss.view(num_rows, -1)
    return ce_loss, z_loss, logits
def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Split into micro-batches.
# print(f"Start preparing micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
micro_batches = self.split_batch(batch)
# In case this helps with memory utilization.
del batch
ce_batch_loss = torch.tensor(0.0, device=self.device)
z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
# print(f"Start training micro-batches at step {self.global_step}") if get_global_rank() == 0 else None
for micro_batch in micro_batches:
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
# Run forward pass.)
# print(f"Start forward pass at step {self.global_step}") if get_global_rank() == 0 else None
# print(f"micro_batch['input_ids'].shape: {micro_batch['input_ids'].shape}") if get_global_rank() == 0 else None
# print(f"min micro_batch['input_ids']: {micro_batch['input_ids'].min()}") if get_global_rank() == 0 else None
# print(f"max micro_batch['input_ids']: {micro_batch['input_ids'].max()}") if get_global_rank() == 0 else None
if get_fs_local_rank() == 0 and (self.global_step == 1421 or self.global_step == 1422 or self.global_step == 1423):
# save micro batch input_ids to file, which is a list of integers
with open(f"micro_batch_step{self.global_step}.txt", "w") as f:
for i in range(micro_batch["input_ids"].shape[0]):
f.write(f"{micro_batch['input_ids'][i].tolist()}\n")
ce_loss, z_loss, logits = self.model_forward(
micro_batch, compute_z_loss=self.cfg.softmax_auxiliary_loss
)
# print(f"End micro_batch at step {self.global_step} with ce_loss {ce_loss}, z_loss {z_loss}") if get_global_rank() == 0 else None
ce_loss = ce_loss / len(micro_batches)
# In case this helps with memory utilization.
del micro_batch
# Update overall CE batch loss.
ce_batch_loss += ce_loss.detach()
# Get loss to optimize for.
if self.cfg.softmax_auxiliary_loss:
assert z_loss is not None
assert z_batch_loss is not None
z_loss = z_loss / len(micro_batches)
loss = ce_loss + z_loss
# Update overall Z batch loss.
z_batch_loss += z_loss.detach()
else:
loss = ce_loss
del logits
# print(f"---before micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
# print(f" loss value: {loss}") if get_global_rank() == 0 else None
# Run backward pass.
loss.backward()
# print(f"---after micro_batch backward at step {self.global_step}") if get_global_rank() == 0 else None
return ce_batch_loss, z_batch_loss
def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
    """Run one optimizer step on ``batch`` and return scalar metrics.

    Steps, in order:
    - optionally record the batch's data indices;
    - zero gradients and move the batch to ``self.device``;
    - run forward/backward via ``train_batch``;
    - optionally reduce the losses to rank 0;
    - clip gradients;
    - set LR and grad-clipping hyperparameters from the scheduler;
    - step the optimizer;
    - gather loss and optimizer metrics.

    :raises ValueError: if the CE loss (or the Z loss, when present) is NaN.
    """
    metrics: Dict[str, float] = {}

    # Write data-indices to file.
    if self.indices_file is not None and "index" in batch:
        indices = "\t".join(str(int(i)) for i in batch["index"])
        self.indices_file.write(f"{self.global_step}\t{indices}\n")

    # Zero-gradients.
    self.optim.zero_grad(set_to_none=True)

    # Move tensors to the right device.
    batch = move_to_device(batch, self.device)

    # Run forward-backward pass.
    ce_batch_loss, z_batch_loss = self.train_batch(batch)

    # Collect loss, potentially reducing over all ranks.
    # NOTE: dist.reduce only delivers the summed result to rank 0 (dst=0). After dividing by the
    # world size, only rank 0 holds the global mean; other ranks' values, and their NaN checks
    # below, should not be read as the global loss.
    if reduce_global_loss:
        dist.reduce(ce_batch_loss, 0)
        ce_batch_loss.div_(get_world_size())
        if z_batch_loss is not None:
            dist.reduce(z_batch_loss, 0)
            z_batch_loss.div_(get_world_size())

    # Clip gradient norms and collect param/gradient/optim metrics.
    should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
    optim_metrics = self.optim.clip_grads_and_collect_metrics(
        self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
    )

    # Adjust the learning rate.
    # Every group currently gets the same values, computed from the global config.
    for group in self.optim.param_groups:
        # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
        # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
        # the corresponding values from `self.cfg`.
        group["lr"] = self.scheduler.get_lr(
            self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
        )
        group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
            self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
        )
        group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
            self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
        )

    # Optimizer step.
    self.optim.step()

    # Collect metrics and check for NaN loss.
    # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
    if torch.isnan(ce_batch_loss):
        raise ValueError("nan loss encountered")
    if z_batch_loss is not None and torch.isnan(z_batch_loss):
        raise ValueError("nan loss encountered")
    for key, value in optim_metrics.items():
        metrics[f"optim/{key}"] = value.item()
    self.cur_train_loss = ce_batch_loss.item()
    self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
    metrics["train/CrossEntropyLoss"] = self.cur_train_loss
    # NOTE(review): math.exp raises OverflowError for losses above ~709.
    metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
    if z_batch_loss is not None:
        metrics["train/ZLoss"] = z_batch_loss.item()

    # Maybe collect post-step optimizer-specific metrics.
    if should_log_optim_metrics_this_step:
        optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
        for key, value in optim_metrics.items():
            metrics[f"optim/{key}"] = value.item()

    return metrics
def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
ce_loss, _, logits = self.model_forward(batch, loss_reduction="none")
return ce_loss.mean(dim=-1), logits
def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
    # Evaluate one batch and pass the result to `evaluator`. Ends with a barrier across ranks.
    device_batch = move_to_device(batch, self.device)
    # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()', hence no_grad.
    with torch.no_grad():
        ce_loss, logits = self.eval_batch(device_batch)
    # The whole batch is passed so the evaluator can read every key it needs.
    evaluator.update_metrics(device_batch, ce_loss, logits)
    barrier()
def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split ``batch`` along its first dimension into micro-batches.

    Each micro-batch holds at most ``cfg.device_train_microbatch_size`` examples (the last one
    may be smaller). When the batch already fits, it is returned unchanged as the only element.
    Tensors are split with ``Tensor.split``; lists are sliced.

    Raises:
        ValueError: if a batch value is neither a ``torch.Tensor`` nor a ``list``.
    """
    chunk = self.cfg.device_train_microbatch_size
    total = batch["input_ids"].shape[0]
    if total <= chunk:
        return [batch]

    pieces: Dict[str, Any] = {}
    for name, item in batch.items():
        if isinstance(item, torch.Tensor):
            pieces[name] = item.split(chunk, dim=0)
        elif isinstance(item, list):
            # Slice starts step by 'chunk' over the batch size, mirroring Tensor.split.
            pieces[name] = [item[start : start + chunk] for start in range(0, total, chunk)]
        else:
            raise ValueError(f"unexpected item in batch: '{name}={item}'")

    num_micro_batches = len(pieces["input_ids"])
    return [
        {name: parts[idx] for name, parts in pieces.items()}  # type: ignore
        for idx in range(num_micro_batches)
    ]
def system_metrics(self) -> Dict[str, float]:
    """Collect system-level metrics on a sparse schedule.

    Peak GPU memory is queried only while ``global_step < 3`` and then every 10th step;
    on other steps (or when ``peak_gpu_memory()`` returns ``None``) the result is empty.
    """
    collected: Dict[str, float] = {}
    step = self.global_step
    if step >= 3 and step % 10 != 0:
        return collected

    peak_mb = peak_gpu_memory()
    if peak_mb is not None:
        collected["System/Peak GPU Memory (MB)"] = peak_mb
    return collected
def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
    """Log ``metrics`` to the console, one ``name=value`` line each, under a ``prefix`` header.

    Metrics whose name starts with ``"optim/"`` are skipped because there are usually too many
    optimizer metrics to be useful on the console.
    """

    def format_float(value: float) -> str:
        # inf / nan have no fixed-point rendering, and int(inf) raises OverflowError, which would
        # otherwise crash the training loop from inside a logging call.
        if not math.isfinite(value):
            return str(value)
        # Choose the precision from the magnitude so negative values are formatted the same
        # way as positive values of the same size.
        magnitude = abs(value)
        if magnitude < 0.0001:
            return str(value)  # str() switches to scientific notation below 1e-4 (0 prints as '0.0')
        elif magnitude > 1000:
            return f"{int(value):,d}"
        elif magnitude > 100:
            return f"{value:.1f}"
        elif magnitude > 10:
            return f"{value:.2f}"
        elif magnitude > 1:
            return f"{value:.3f}"
        else:
            return f"{value:.4f}"

    lines = [
        f" {name}={format_float(value)}"
        for name, value in metrics.items()
        if not name.startswith("optim/")  # there's too many optimizer metrics
    ]
    log.info(f"{prefix}\n" + "\n".join(lines))
def should_log_optim_metrics_this_step(self) -> bool:
    """Return whether optimizer-specific metrics should be collected on the current step.

    These metrics are only sent to W&B (there are usually too many for the console), so this
    is always ``False`` when ``cfg.wandb`` is not configured. The effective interval is
    ``cfg.optimizer.metrics_log_interval`` when set, but never smaller than the W&B log
    interval; without it, the W&B log interval is used directly.
    """
    wandb_cfg = self.cfg.wandb
    if wandb_cfg is None:
        return False

    requested = self.cfg.optimizer.metrics_log_interval
    interval = (
        wandb_cfg.log_interval
        if requested is None
        else max(requested, wandb_cfg.log_interval)
    )
    return self.global_step % interval == 0
def should_log_this_step(self) -> bool:
    """Return whether metrics are due on the current step for the console or for W&B."""
    step = self.global_step
    if step % self.cfg.console_log_interval == 0:
        return True
    wandb_cfg = self.cfg.wandb
    return wandb_cfg is not None and step % wandb_cfg.log_interval == 0
def eval(self) -> Dict[str, Any]:
    """Run every configured evaluator and return their metrics merged into one dict.

    Gradients are cleared and the model is put into ``eval`` mode first. The model is left in
    ``eval`` mode afterwards; callers switch it back to ``train`` themselves. For each
    evaluator, a positive batch limit (the evaluator's ``subset_num_batches``, falling back to
    ``cfg.eval_subset_num_batches``) caps how many loader batches are used. A non-positive
    limit means the whole loader is consumed.
    """
    # Drop gradients and switch the model to evaluation mode.
    self.optim.zero_grad(set_to_none=True)
    self.fsdp_model.eval()

    all_metrics = {}
    for evaluator in self.evaluators:
        log.info(f"Running evaluation for '{evaluator.label}'...")
        evaluator.reset_metrics()

        batches = iter(evaluator.eval_loader)

        # A per-evaluator limit takes precedence over the global config value.
        limit = evaluator.subset_num_batches
        if limit is None:
            limit = self.cfg.eval_subset_num_batches
        if limit > 0:
            limit = min(limit, len(evaluator.eval_loader))
            batches = islice(batches, limit)

        for step_num, batch in enumerate(batches, start=1):
            self.eval_step(batch, evaluator)
            if step_num == limit or step_num % self.cfg.console_log_interval == 0:
                log.info(f"[eval_step={step_num}/{limit}]")

        evaluator_metrics = evaluator.compute_metrics()
        all_metrics.update(evaluator_metrics)
        self.log_metrics_to_console(f"{evaluator.label}", evaluator_metrics)

        del batches

    return all_metrics
def check_if_cancelled(self) -> Tuple[bool, int]:
    """Decide on rank 0 whether the run should stop, then share the decision with all ranks.

    Rank 0 checks, in order:
      1. whether ``cfg.time_limit`` seconds have passed since ``fit()`` started;
      2. whether early stopping triggers, i.e. after warmup the current training loss exceeds
         ``cfg.early_stopping_factor`` times the minimum loss seen so far;
      3. whether the W&B run carries a 'cancel' / 'canceled' / 'cancelled' tag. This is only
         checked when a W&B run is active and ``WANDB_API_KEY`` is set. Lookup failures are
         logged and otherwise ignored.

    Returns:
        ``(run_canceled, extra_steps)``: whether the run was canceled and how many more steps to
        train before stopping. ``extra_steps`` is synchronized across ranks only when the run
        was canceled.
    """
    should_cancel = False
    cancel_reason: Optional[str] = None
    extra_steps = 0
    if get_global_rank() == 0:
        if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
            # First check if we've reached the training time limit.
            should_cancel = True
            cancel_reason = "time limit reached"
            extra_steps = self.cfg.extra_steps_after_cancel
        elif (
            self.cfg.early_stopping_factor is not None
            and self.global_step > self.cfg.scheduler.t_warmup
            and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
        ):
            # Next check if early stopping loss criteria is met.
            # 'extra_steps' stays 0 in this branch, so no extra steps are granted.
            should_cancel = True
            cancel_reason = "early stopping from loss increase"
        elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
            # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag.
            # We won't see it in the run object. So we have to use the import/export API to check.
            from requests.exceptions import RequestException

            try:
                api = wandb.Api(api_key=api_key)
                run = api.run(wandb.run.path)
                for tag in run.tags or []:
                    if tag.lower() in {"cancel", "canceled", "cancelled"}:
                        should_cancel = True
                        cancel_reason = "Weights & Biases tag"
                        extra_steps = self.cfg.extra_steps_after_cancel
                        break
            except RequestException as e:
                # Best effort: a failed lookup must not crash training, but it should not be
                # silent either, since it means W&B-tag cancellation is currently not working.
                log.warning(f"Failed to check W&B run tags for cancellation: {e}")

    run_canceled = synchronize_flag(should_cancel, self.device)
    if run_canceled:
        extra_steps = synchronize_value(extra_steps, self.device)
        # Only rank 0 knows the reason; other ranks log a generic message.
        if cancel_reason is None:
            if extra_steps > 0:
                log.warning(f"Run canceled, stopping in {extra_steps} more steps...")
            else:
                log.warning("Run canceled")
        else:
            if extra_steps > 0:
                log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
            else:
                log.warning(f"Run canceled due to {cancel_reason}")

    return run_canceled, extra_steps
def fit(self):
    """Run the main training loop.

    Converts ``cfg.stop_after`` into an absolute ``cfg.stop_at``, optionally evaluates a loaded
    checkpoint, sets up speed/LR monitors and the optional Python / PyTorch profilers, then
    trains epoch by epoch. Every step it:
      - updates bookkeeping;
      - runs ``train_step``;
      - logs metrics to the console and W&B;
      - periodically checks for cancellation;
      - saves sharded / ephemeral / unsharded checkpoints;
      - runs evaluations.

    Training ends when the last epoch's data is exhausted or ``stop_at`` is reached. A final
    checkpoint is then written unless the run was canceled; in that case the checkpoint saved
    right after cancellation was the last one, because ``save_checkpoints`` is switched off.
    """
    # Turn the relative 'stop_after' budget into an absolute step, keeping the earlier limit.
    if self.cfg.stop_after is not None:
        if self.cfg.stop_at is None:
            self.cfg.stop_at = self.global_step + self.cfg.stop_after
        else:
            self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)

    self._start_time = time.time()
    self._gc_init_state = gc.isenabled()  # cache if garbage collection is enabled, reset on close.

    # Disable automatic garbage collection, FSDP doesn't work well with it.
    # Collection is instead triggered manually every 'gen1_gc_interval' steps below.
    if self.cfg.gen1_gc_interval is not None:
        gc.disable()

    if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
        eval_metrics = self.eval()
        if wandb.run is not None:
            wandb.log(eval_metrics, step=self.global_step)

    # Set model to 'train' mode.
    self.fsdp_model.train()

    # Initialize monitors.
    assert self.cfg.device_train_batch_size is not None
    speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
    lr_monitor = LRMonitor(self.optim)

    # Log system metrics at the start of training.
    sys_metrics = self.system_metrics()
    if sys_metrics:
        self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
        if wandb.run is not None:
            wandb.log(sys_metrics, step=0)

    # Python Profiler stuff
    # (enabled at global step 5 and reported/disabled at step 8, see the bottom of the batch loop).
    if self.cfg.python_profiling:
        python_profiler = cProfile.Profile()
    else:
        python_profiler = None

    # PyTorch Profiler stuff
    # Only rank 0 profiles; every other rank gets a no-op context manager.
    if self.cfg.torch_profiling and get_global_rank() == 0:
        from torch.profiler import schedule

        # Skip 1 step, warm up for 5, record 3 active steps, and do this cycle once.
        profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)

        def on_trace_ready(p):
            # Log the top ops by GPU and CPU time, save a Chrome trace locally, and upload it
            # if a remote save folder is configured.
            profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
            profiler_output_dir.mkdir(exist_ok=True)

            output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
            log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
            output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
            log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")

            p.export_chrome_trace(
                str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
            )
            if self.cfg.remote_save_folder is not None:
                upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
                log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
                upload(trace_path, f"{upload_folder}/{trace_path.name}")

        from torch.profiler import ProfilerActivity

        torch_profiler = torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=False,
            profile_memory=False,
            with_stack=True,
            schedule=profiling_schedule,
            on_trace_ready=on_trace_ready,
        )
        del profiling_schedule
    else:
        import contextlib

        torch_profiler = contextlib.nullcontext()

    # Train.
    first_batch: bool = True
    cancel_initiated: bool = False
    stop_at: Optional[int] = self.cfg.stop_at
    save_checkpoints: bool = True

    # 'p' is None when profiling is disabled, because 'nullcontext()' yields None.
    with torch_profiler as p:
        for epoch in range(self.epoch or 0, self.max_epochs):
            for batch in self.train_loader:
                # Bookkeeping.
                # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
                # batches see the same number of tokens, which should be the case for language model pre-training
                # (at least when drop_last=True).
                # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
                # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
                # fail loudly.
                batch_size, seq_len = batch["input_ids"].shape
                assert seq_len == self.cfg.model.max_sequence_length
                assert batch_size == self.cfg.device_train_batch_size
                global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
                self.global_step += 1
                self.global_train_examples_seen_this_epoch += global_batch_size
                self.global_train_tokens_seen += global_batch_size * seq_len
                speed_monitor.batch_start(
                    self.global_train_tokens_seen,
                    batch_size * seq_len,  # num tokens in batch for this device
                    # We start monitoring speed after the first batch since the first
                    # batch might be an outlier due to compiling and other initialization overhead.
                    record=not first_batch,
                )

                # Computed once per step; it also controls whether train_step reduces the loss globally.
                should_log_this_step = self.should_log_this_step()

                # Run train step on batch.
                metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

                # Maybe collect other metrics.
                if should_log_this_step:
                    # Speed metrics.
                    metrics.update(speed_monitor.check())
                    # System metrics.
                    metrics.update(self.system_metrics())
                    # Learning rate metrics.
                    metrics.update(lr_monitor.check())

                # Log metrics to console.
                # Rank 0 prints the full metrics; other ranks only print the step marker.
                if self.global_step % self.cfg.console_log_interval == 0:
                    if get_global_rank() == 0:
                        self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)
                    else:
                        log.info(f"[step={self.global_step}/{self.max_steps}]")

                # Log metrics to W&B.
                if (
                    wandb.run is not None
                    and self.cfg.wandb is not None
                    and self.global_step % self.cfg.wandb.log_interval == 0
                ):
                    wandb.log(metrics, step=self.global_step)

                # Check if/when run should be canceled.
                # Once canceled, stop_at is pulled in to at most 'extra_steps' steps from now.
                if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                    cancel_initiated, extra_steps = self.check_if_cancelled()
                    if cancel_initiated:
                        stop_at = (
                            self.global_step + extra_steps
                            if stop_at is None
                            else min(self.global_step + extra_steps, stop_at)
                        )

                # Maybe save sharded checkpoint.
                # A cancellation forces an immediate sharded checkpoint regardless of the interval.
                if save_checkpoints and (
                    cancel_initiated
                    or (
                        self.global_step % self.cfg.save_interval == 0
                        and self.cfg.save_num_checkpoints_to_keep != 0
                    )
                ):
                    log.info("Saving checkpoint...")
                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                    log.info(f"Checkpoint saved to {checkpoint_path}")

                    # Remove any ephemeral checkpoints.
                    while self.ephemeral_checkpoints:
                        self.remove_ephemeral_checkpoint()

                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
                    speed_monitor.reset()

                    # If the run was just canceled this will be the final checkpoint.
                    if cancel_initiated:
                        save_checkpoints = False
                elif (
                    self.cfg.save_interval_ephemeral is not None
                    and self.global_step % self.cfg.save_interval_ephemeral == 0
                ):
                    log.info("Saving ephemeral checkpoint...")
                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                    log.info(f"Checkpoint saved to {checkpoint_path}")

                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
                    speed_monitor.reset()

                # Maybe save unsharded checkpoint.
                if (
                    save_checkpoints
                    and self.cfg.save_interval_unsharded is not None
                    and self.global_step % self.cfg.save_interval_unsharded == 0
                    and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
                ):
                    log.info("Saving unsharded checkpoint...")
                    checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
                    speed_monitor.reset()

                # Maybe run evaluations.
                if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                    eval_metrics = self.eval()

                    # Log metrics to W&B.
                    if wandb.run is not None:
                        wandb.log(eval_metrics, step=self.global_step)

                    # Reset speed monitor so that we don't count the time taken to run evaluations.
                    speed_monitor.reset()

                    # Reset model to 'train' mode.
                    self.fsdp_model.train()

                # End of batch.
                first_batch = False
                if p is not None:
                    p.step()

                if stop_at is not None and self.global_step >= stop_at:
                    break

                # Run generation 1 garbage collection.
                if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
                    gc.collect(1)

                # Python Profiler stuff
                # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                if python_profiler is not None:
                    if self.global_step == 5:
                        python_profiler.enable()
                    elif self.global_step == 8:
                        python_profiler.disable()
                        python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                        python_profiler = None
            else:
                # The 'for ... else' branch runs only when the loader was exhausted without a
                # 'break', i.e. a full epoch completed.
                log.info("Training epoch complete")
                self.epoch = epoch + 1
                self.global_train_examples_seen_this_epoch = 0
                if self.epoch < self.max_epochs:
                    self.dataset.reshuffle()
                continue

            # Only reached when the batch loop broke out early because 'stop_at' was hit.
            break

    # Save final checkpoint.
    # Prefer an unsharded checkpoint when unsharded saving is configured; skip it if this step
    # was already saved.
    if save_checkpoints:
        if (
            self.cfg.save_interval_unsharded is not None
            and self.last_unsharded_checkpoint_step != self.global_step
        ):
            log.info("Saving final unsharded model checkpoint...")
            checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
            log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
        elif (
            self.cfg.save_num_checkpoints_to_keep != 0
            and self.last_sharded_checkpoint_step != self.global_step
        ):
            log.info("Saving final checkpoint...")
            checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
            log.info(f"Checkpoint saved to {checkpoint_path}")
def close(self, exit_code: int = 0) -> None:
    """Release trainer resources and finish the W&B run.

    Steps:
      - call ``gc_cuda()`` (presumably to free cached CUDA memory -- confirm);
      - flush and close the indices file, if one is open;
      - restore the garbage-collection state cached at the start of ``fit()``;
      - finish the active W&B run with ``exit_code``.

    The indices-file handle is cleared after closing, so calling ``close()`` more than once
    (e.g. explicitly and again via ``__exit__``) does not try to flush a closed file.

    Args:
        exit_code: Passed to ``wandb.finish``; ``__exit__`` uses 0 on success and 1 on error.
    """
    gc_cuda()

    if self.indices_file is not None:
        self.indices_file.flush()
        self.indices_file.close()
        # Flushing an already-closed file raises ValueError, so drop the handle.
        self.indices_file = None

    if self._gc_init_state:
        gc.enable()
    else:
        gc.disable()

    if wandb.run is not None:
        wandb.finish(exit_code=exit_code, quiet=True)
def __enter__(self) -> Trainer:
    """Enter the ``with`` block; the trainer itself is the managed object."""
    trainer = self
    return trainer
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Close the trainer when leaving the ``with`` block.

    ``close`` receives exit code 1 if an exception escaped the block and 0 otherwise. The
    method returns ``None``, so any exception still propagates.
    """
    del exc_val, exc_tb  # only the exception type is needed
    exit_code = 1 if exc_type is not None else 0
    self.close(exit_code)
|