zehui127 committed
Commit dd2d94a · verified · 1 Parent(s): 6fdc9fc

Delete checkpoint.py

Files changed (1)
  1. checkpoint.py +0 -1732
checkpoint.py DELETED
@@ -1,1732 +0,0 @@
1
- import gc
2
- import io
3
- import logging
4
- import pickle
5
- import shutil
6
- import traceback
7
- from abc import ABCMeta, abstractmethod
8
- from collections import defaultdict
9
- from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
- from contextlib import contextmanager
11
- from copy import deepcopy
12
- from dataclasses import dataclass, field, replace
13
- from functools import reduce
14
- from multiprocessing import shared_memory
15
- from pathlib import Path
16
- from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
-
18
- import numpy as np
19
- import torch
20
- import torch.distributed.checkpoint as dist_cp
21
- import torch.multiprocessing as mp
22
- from packaging import version
23
- from torch.distributed import _remote_device
24
- from torch.distributed._shard._utils import narrow_tensor_by_index
25
- from torch.distributed._shard.metadata import ShardMetadata
26
- from torch.distributed._shard.sharded_tensor import ShardedTensor
27
- from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
28
- from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
29
- from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
30
- from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
31
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32
- from torch.distributed.fsdp import StateDictType
33
- from torch.distributed.fsdp.api import (
34
- FullOptimStateDictConfig,
35
- FullStateDictConfig,
36
- ShardedOptimStateDictConfig,
37
- ShardedStateDictConfig,
38
- )
39
- from torch.futures import Future
40
-
41
- try:
42
- from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
43
- except ModuleNotFoundError:
44
- from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
45
-
46
- from olmo import util
47
-
48
- from .aliases import PathOrStr
49
- from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50
- from .exceptions import OLMoCheckpointError
51
- from .optim import Optimizer, fix_optim_state_dict
52
- from .safetensors_util import safetensors_file_to_state_dict
53
- from .torch_util import (
54
- barrier,
55
- gc_cuda,
56
- get_fs_local_rank,
57
- get_global_rank,
58
- get_world_size,
59
- )
60
- from .util import (
61
- _get_s3_client,
62
- default_thread_count,
63
- dir_is_empty,
64
- get_bytes_range,
65
- get_progress_bar,
66
- resource_path,
67
- upload,
68
- wait_for,
69
- )
70
-
71
- __all__ = [
72
- "save_fsdp_model_and_optim_state",
73
- "load_fsdp_model_and_optim_state",
74
- "load_fsdp_optim_state",
75
- "save_state_dict",
76
- "load_state_dict",
77
- "load_model_state",
78
- "RemoteFileSystemWriter",
79
- "RemoteFileSystemReader",
80
- "Checkpointer",
81
- "FullCheckpointer",
82
- "TorchNewStyleShardedCheckpointer",
83
- "TorchLegacyShardedCheckpointer",
84
- "LocalShardedCheckpointer",
85
- "build_sharded_checkpointer",
86
- ]
87
-
88
-
89
- log = logging.getLogger(__name__)
90
-
91
- MODEL_AND_OPTIM_FOLDER = "model_and_optim"
92
-
93
-
94
- def save_fsdp_model_and_optim_state(
95
- checkpoint_dir: PathOrStr,
96
- fsdp_model: FSDP,
97
- optim: Optimizer,
98
- *,
99
- upload_to: Optional[str] = None,
100
- save_overwrite: bool = False,
101
- ):
102
- """
103
- Use this to save a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
104
- functions. This should be used during distributed training and should be called by all ranks.
105
-
106
- :param checkpoint_dir: The directory to save to.
107
- :param fsdp_model: The FSDP model.
108
- :param optim: The FSDP model's optimizer.
109
- :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
110
- :param save_overwrite: Overwrite existing files.
111
-
112
- :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
113
- """
114
- checkpoint_dir = Path(checkpoint_dir)
115
- target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
116
- if save_overwrite:
117
- if get_fs_local_rank() == 0:
118
- shutil.rmtree(target_dir, ignore_errors=True)
119
- elif not dir_is_empty(target_dir):
120
- raise FileExistsError(target_dir)
121
- barrier()
122
- if get_fs_local_rank() == 0:
123
- target_dir.mkdir(exist_ok=True, parents=True)
124
- barrier()
125
- with FSDP.state_dict_type(
126
- fsdp_model,
127
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
128
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
129
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
130
- ):
131
- model_and_optim_state = {
132
- "model": fsdp_model.state_dict(),
133
- "optim": FSDP.optim_state_dict(fsdp_model, optim),
134
- }
135
- dist_cp.save_state_dict(
136
- model_and_optim_state,
137
- RemoteFileSystemWriter(
138
- target_dir,
139
- upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
140
- save_overwrite=save_overwrite,
141
- ),
142
- )
143
-
144
-
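For context, the docstring above describes the calling contract; the following is a minimal, hypothetical usage sketch (the ``fsdp_model``, ``optim``, step counter, and bucket path are assumed examples, not defined in this file). Every rank calls it with the same arguments.

# Illustrative sketch only; all names below are assumed examples.
step = 1000
save_fsdp_model_and_optim_state(
    f"/checkpoints/step{step}",                  # local target directory
    fsdp_model,                                  # an FSDP-wrapped model
    optim,                                       # its optimizer
    upload_to=f"s3://my-bucket/run/step{step}",  # optional remote mirror
    save_overwrite=True,
)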
145
- def load_fsdp_model_and_optim_state(
146
- checkpoint_dir: PathOrStr,
147
- fsdp_model: FSDP,
148
- optim: Optimizer,
149
- *,
150
- local_cache: Optional[PathOrStr] = None,
151
- load_optimizer_state: bool = True,
152
- ):
153
- """
154
- Use this to load a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
155
- functions. This should be used during distributed training and should be called by all ranks.
156
-
157
- :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
158
- :param fsdp_model: The FSDP model.
159
- :param optim: The FSDP model's optimizer.
160
- :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
161
- remote "directory" but there might be a cached version of the same artifacts.
162
- :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
163
-
164
- :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
165
- """
166
- load_path = str(checkpoint_dir).rstrip("/")
167
- local_cache = None if local_cache is None else Path(local_cache)
168
- with FSDP.state_dict_type(
169
- fsdp_model,
170
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
171
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
172
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
173
- ):
174
- # Load the model state dict in place.
175
- log.info("Loading model state...")
176
- model_state = {"model": fsdp_model.state_dict()}
177
- dist_cp.load_state_dict(
178
- model_state,
179
- RemoteFileSystemReader(
180
- f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
181
- local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
182
- ),
183
- )
184
- fsdp_model.load_state_dict(model_state["model"])
185
-
186
- if not load_optimizer_state:
187
- return
188
-
189
- # Load optim state dict in place.
190
- log.info("Loading sharded optimizer state...")
191
- optim_state = load_sharded_optimizer_state_dict(
192
- model_state_dict=model_state["model"],
193
- optimizer_key="optim",
194
- storage_reader=RemoteFileSystemReader(
195
- f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
196
- local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
197
- ),
198
- )
199
- del model_state
200
- gc_cuda()
201
- load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
202
-
203
-
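A matching, hypothetical restore sketch (the remote path and cache directory are assumed examples); all ranks call this before training resumes:

# Illustrative sketch only; paths are assumed examples.
load_fsdp_model_and_optim_state(
    "s3://my-bucket/run/step1000",           # remote checkpoint "directory"
    fsdp_model,
    optim,
    local_cache="/tmp/ckpt-cache/step1000",  # used when a cached copy exists
    load_optimizer_state=True,
)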
204
- def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
205
- log.info("Flattening sharded optimizer state...")
206
- # NOTE: Careful! The order of these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
207
- if version.parse(torch.__version__) < version.parse("2.1.0"):
208
- flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
209
- else:
210
- flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
211
- del optim_state
212
- gc.collect()
213
- log.info("Loading flattened optimizer state...")
214
- # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
215
- # which takes up unnecessary GPU memory.
216
- for state in flattened_osd["state"].values():
217
- for k in state.keys():
218
- v = state[k]
219
- if isinstance(v, torch.Tensor):
220
- state[k] = v.to(device="cpu")
221
- gc_cuda()
222
- optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
223
-
224
-
225
- def save_state_dict(
226
- checkpoint_dir: PathOrStr,
227
- fname: str,
228
- state_dict: Dict[str, Any],
229
- *,
230
- upload_to: Optional[str] = None,
231
- save_overwrite: bool = False,
232
- synchronize: bool = True,
233
- ):
234
- """
235
- Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
236
- This can be used during distributed training or not. If used during distributed training, the ``fname`` should be unique
237
- for each rank.
238
-
239
- :param checkpoint_dir: The directory to save to.
240
- :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
241
- :param state_dict: The state dict to save.
242
- :param upload_to: Optional, a remote "directory" to upload the file to.
243
- :param save_overwrite: Overwrite existing files.
244
- :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
245
- this function from a single rank.
246
-
247
- :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
248
- """
249
- checkpoint_dir = Path(checkpoint_dir)
250
- target_path = checkpoint_dir / fname
251
- if save_overwrite:
252
- target_path.unlink(missing_ok=True)
253
- elif target_path.is_file():
254
- raise FileExistsError(target_path)
255
- if synchronize:
256
- barrier()
257
- target_path.parent.mkdir(exist_ok=True, parents=True)
258
- if synchronize:
259
- barrier()
260
- torch.save(state_dict, target_path)
261
- if upload_to is not None:
262
- upload_target = f"{upload_to.rstrip('/')}/{fname}"
263
- log.info(f"Uploading {target_path} to {upload_target}...")
264
- upload(target_path, upload_target, save_overwrite=save_overwrite)
265
-
266
-
267
- def load_state_dict(
268
- checkpoint_dir: PathOrStr,
269
- fname: str,
270
- *,
271
- local_cache: Optional[PathOrStr] = None,
272
- map_location: Optional[str] = None,
273
- ):
274
- """
275
- Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
276
- This can be used during distributed training or not.
277
-
278
- :param checkpoint_dir: A local or remote checkpoint directory.
279
- :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
280
- :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
281
- remote "directory" but there might be a cached version of the same artifacts.
282
-
283
- :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
284
- """
285
- if fname.endswith(".pt"):
286
- # Try safetensors version first.
287
- try:
288
- path = resource_path(
289
- str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
290
- )
291
- return safetensors_file_to_state_dict(path, map_location=map_location)
292
- except FileNotFoundError:
293
- pass
294
-
295
- path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
296
- return torch.load(path, map_location=map_location)
297
-
298
-
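Since per-rank files must not collide, a typical pattern is to embed the global rank in ``fname``; a short, hypothetical sketch (the directory and state contents are assumed examples):

# Illustrative sketch only; directory and contents are assumed examples.
save_state_dict(
    "/checkpoints/step1000",
    f"train/rank{get_global_rank()}.pt",   # unique per rank
    {"global_step": 1000, "epoch": 1},
    save_overwrite=True,
)
# load_state_dict() resolves the same relative fname; for a ".pt" name it first
# looks for a ".safetensors" sibling and falls back to torch.load() otherwise.
trainer_state = load_state_dict("/checkpoints/step1000", f"train/rank{get_global_rank()}.pt")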
299
- def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
300
- """
301
- Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
302
- Note that ``model`` should not be wrapped with FSDP.
303
- """
304
- state_dict = {"model": model.state_dict()}
305
- dist_cp.load_state_dict(
306
- state_dict,
307
- RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
308
- no_dist=True,
309
- )
310
- model.load_state_dict(state_dict["model"])
311
-
312
-
313
- class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
314
- """
315
- A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
316
- directly to a cloud bucket when ``upload_to`` is specified.
317
- """
318
-
319
- def __init__(
320
- self,
321
- path: PathOrStr,
322
- single_file_per_rank: bool = True,
323
- sync_files: bool = True,
324
- thread_count: Optional[int] = None,
325
- per_thread_copy_ahead: int = 10_000_000,
326
- upload_to: Optional[str] = None,
327
- save_overwrite: bool = False,
328
- ) -> None:
329
- if thread_count is not None and thread_count <= 0:
330
- raise ValueError("thread count must be at least 1")
331
- super().__init__(
332
- path,
333
- single_file_per_rank=single_file_per_rank,
334
- sync_files=sync_files,
335
- # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
336
- # returns because uploading big checkpoint files with multiple threads causes
337
- # boto3 to fail in weird ways.
338
- thread_count=thread_count or 1,
339
- per_thread_copy_ahead=per_thread_copy_ahead,
340
- )
341
- self.upload_to = None if upload_to is None else upload_to.rstrip("/")
342
- self.save_overwrite = save_overwrite
343
-
344
- def write_data(
345
- self,
346
- plan: dist_cp.SavePlan,
347
- planner: dist_cp.SavePlanner,
348
- ) -> Future[List[WriteResult]]:
349
- fut = super().write_data(plan, planner)
350
- if self.upload_to is not None:
351
- files_to_upload = set()
352
- for write_result in fut.wait():
353
- files_to_upload.add(write_result.storage_data.relative_path)
354
-
355
- # Create the global S3 client up front to work around a threading issue in boto.
356
- if self.upload_to.startswith("s3://"):
357
- _get_s3_client("s3")
358
- elif self.upload_to.startswith("r2://"):
359
- _get_s3_client("r2")
360
-
361
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
362
- futures = []
363
- for fname in files_to_upload:
364
- source = self.path / fname
365
- target = f"{self.upload_to}/{fname}"
366
- log.info(f"Uploading {source} to {target}...")
367
- futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
368
- for f in as_completed(futures):
369
- try:
370
- f.result()
371
- except BaseException:
372
- # NOTE: we might get an error here that can't be pickled, which causes a different failure
373
- # later when PyTorch tries to reduce that error across ranks. So here we just make
374
- # sure we're raising a simple error type that can be pickled.
375
- raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
376
- return fut
377
-
378
- def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
379
- super().finish(metadata, results)
380
- if self.upload_to is not None:
381
- source = self.path / ".metadata"
382
- target = f"{self.upload_to}/.metadata"
383
- log.info(f"Uploading {source} to {target}...")
384
- upload(source, target, save_overwrite=self.save_overwrite)
385
-
386
-
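How this writer is wired into ``torch.distributed.checkpoint`` is shown by ``save_fsdp_model_and_optim_state`` above; a condensed, hypothetical sketch of the same pattern (paths and the state dict are assumed examples):

# Illustrative sketch only; paths are assumed examples.
sharded_state = {"model": fsdp_model.state_dict()}   # sharded state dict from FSDP
dist_cp.save_state_dict(
    sharded_state,
    RemoteFileSystemWriter(
        "/checkpoints/step1000/model_and_optim",
        upload_to="s3://my-bucket/run/step1000/model_and_optim",
        save_overwrite=True,
    ),
)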
387
- class RemoteFileSystemReader(dist_cp.StorageReader):
388
- """
389
- A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
390
- that can read data directly from cloud storage as well as a local directory.
391
- """
392
-
393
- def __init__(
394
- self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
395
- ):
396
- super().__init__()
397
- if thread_count is not None and thread_count <= 0:
398
- raise ValueError("thread count must be at least 1")
399
- self.path = str(path).rstrip("/")
400
- self.cache = None if local_cache is None else Path(local_cache)
401
- self.thread_count = thread_count or default_thread_count()
402
- self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
403
- self._metadata: Optional[Metadata] = None
404
-
405
- def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
406
- if self.cache is not None and (path := self.cache / relative_path).is_file():
407
- return get_bytes_range(path, offset, length)
408
- else:
409
- return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
410
-
411
- def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
412
- sinfo = self.storage_data[read_item.storage_index]
413
- content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
414
- return (read_item, content)
415
-
416
- def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
417
- # Create the global S3 client up front to work around a threading issue in boto.
418
- if isinstance(self.path, str):
419
- if self.path.startswith("s3://"):
420
- _get_s3_client("s3")
421
- elif self.path.startswith("r2://"):
422
- _get_s3_client("r2")
423
-
424
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
425
- read_item_content_futures = []
426
- for read_item in plan.items:
427
- read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
428
- read_item_content_results = []
429
- for f in as_completed(read_item_content_futures):
430
- try:
431
- read_item_content_results.append(f.result())
432
- except BaseException:
433
- # NOTE: we might get an error here that can't be pickled, which causes a different failure
434
- # later when PyTorch tries to reduce that error across ranks. So here we just make
435
- # sure we're raising a simple error type that can be pickled.
436
- raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
437
-
438
- # Modified from `FileSystemReader.read_data()`
439
- for read_item, content in read_item_content_results:
440
- bytes = io.BytesIO(content)
441
- bytes.seek(0)
442
- if read_item.type == LoadItemType.BYTE_IO:
443
- planner.load_bytes(read_item, bytes)
444
- else:
445
- tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
446
- tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
447
- target_tensor = planner.resolve_tensor(read_item).detach()
448
-
449
- assert (
450
- target_tensor.size() == tensor.size()
451
- ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
452
- target_tensor.copy_(tensor)
453
- planner.commit_tensor(read_item, target_tensor)
454
-
455
- fut: Future = Future()
456
- fut.set_result(None)
457
- return fut
458
-
459
- def read_metadata(self) -> Metadata:
460
- if self._metadata is None:
461
- with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
462
- self._metadata = pickle.load(metadata_file)
463
- return self._metadata
464
-
465
- def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
466
- del is_coordinator
467
- self.storage_data = metadata.storage_data
468
- assert self.storage_data is not None
469
-
470
- def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
471
- return plan
472
-
473
- def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
474
- return global_plan
475
-
476
-
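Because the reader fetches byte ranges itself, it also works outside a distributed run; a hypothetical single-process sketch along the lines of ``load_model_state`` above (the model and paths are assumed examples):

# Illustrative sketch only; `model` is an assumed unwrapped nn.Module.
state = {"model": model.state_dict()}
dist_cp.load_state_dict(
    state,
    RemoteFileSystemReader(
        "s3://my-bucket/run/step1000/model_and_optim",
        local_cache="/tmp/ckpt-cache/model_and_optim",
        thread_count=8,
    ),
    no_dist=True,   # no torch.distributed initialization required
)
model.load_state_dict(state["model"])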
477
- class Checkpointer(metaclass=ABCMeta):
478
- def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
479
- self.cfg = cfg
480
- self.thread_count = thread_count or default_thread_count()
481
-
482
- @abstractmethod
483
- def save_checkpoint(
484
- self,
485
- dir: PathOrStr,
486
- fsdp_model: FSDP,
487
- optim: Optimizer,
488
- train_state: Dict[str, Any],
489
- *,
490
- upload_to: Optional[str] = None,
491
- ) -> None:
492
- raise NotImplementedError
493
-
494
- @abstractmethod
495
- def restore_checkpoint(
496
- self,
497
- load_path: PathOrStr,
498
- fsdp_model: FSDP,
499
- optim: Optimizer,
500
- *,
501
- local_cache: Optional[PathOrStr] = None,
502
- load_optimizer_state: bool = True,
503
- ) -> Dict[str, Any]:
504
- """
505
- Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
506
- """
507
- raise NotImplementedError
508
-
509
- def unshard_checkpoint(
510
- self,
511
- load_path: PathOrStr,
512
- *,
513
- local_cache: Optional[PathOrStr] = None,
514
- load_optimizer_state: bool = True,
515
- load_trainer_state: bool = True,
516
- device: Optional[torch.device] = None,
517
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
518
- """
519
- Unshard a checkpoint.
520
-
521
- Note this is not marked abstract because child classes are not required to implement this.
522
- """
523
- del load_path, local_cache, load_optimizer_state, load_trainer_state, device
524
- raise NotImplementedError
525
-
526
- @contextmanager
527
- def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
528
- # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
529
- checkpoint_dir = Path(dir)
530
- if not dir_is_empty(checkpoint_dir):
531
- if self.cfg.save_overwrite:
532
- if get_fs_local_rank() == 0:
533
- shutil.rmtree(checkpoint_dir, ignore_errors=True)
534
- else:
535
- raise FileExistsError(checkpoint_dir)
536
- # No need to mkdir here since we'll directly replace the temporary directory with
537
- # this directory below.
538
- barrier()
539
-
540
- # Prepare temporary directory. We don't have to be as careful here, we can
541
- # just remove it if it already exists.
542
- checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
543
- if get_fs_local_rank() == 0:
544
- shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
545
- checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
546
-
547
- barrier()
548
-
549
- # Yield temporary directory for `.save_checkpoint()` to use.
550
- yield checkpoint_dir_tmp
551
-
552
- barrier()
553
-
554
- # Finally if all went well replace the temporary directory with the actual
555
- # checkpoint directory.
556
- if get_fs_local_rank() == 0:
557
- # Replace temp directory with target checkpoint directory.
558
- try:
559
- checkpoint_dir_tmp.replace(checkpoint_dir)
560
- except FileNotFoundError:
561
- # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
562
- # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
563
- # file-systems.
564
- if not checkpoint_dir.exists():
565
- raise
566
-
567
- # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
568
- # replacing the temp directory with the final directory from rank 0 might not be immediately
569
- # realized in the file systems of the other ranks.
570
- # So we wait here across all ranks until that final checkpoint directory is visible.
571
- wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
572
-
573
- barrier()
574
-
575
- def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
576
- if get_global_rank() == 0:
577
- log.info("Saving config...")
578
- self.cfg.save(config_path := Path(dir) / "config.yaml")
579
- if upload_to is not None:
580
- upload_target = f"{upload_to}/config.yaml"
581
- log.info(f"Uploading {config_path} to {upload_target}")
582
- upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
583
-
584
-
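To make the abstract contract concrete, here is a minimal, hypothetical subclass skeleton; it only illustrates how ``_temporary_wd`` and ``_save_config`` frame a save, and is not one of the real checkpointers defined below.

# Illustrative skeleton only; not a real checkpointer from this module.
class MyCheckpointer(Checkpointer):
    def save_checkpoint(self, dir, fsdp_model, optim, train_state, *, upload_to=None):
        # Write into a "-tmp" directory that is atomically swapped into place on success.
        with self._temporary_wd(dir) as checkpoint_dir:
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                train_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )
            # ... save model and optimizer state here ...
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(self, load_path, fsdp_model, optim, *, local_cache=None, load_optimizer_state=True):
        # ... load model and optimizer state here ...
        return load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)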
585
- class FullCheckpointer(Checkpointer):
586
- """
587
- A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
588
- """
589
-
590
- def save_checkpoint(
591
- self,
592
- dir: PathOrStr,
593
- fsdp_model: FSDP,
594
- optim: Optimizer,
595
- trainer_state: Dict[str, Any],
596
- *,
597
- upload_to: Optional[str] = None,
598
- ) -> None:
599
- with self._temporary_wd(dir) as checkpoint_dir:
600
- with FSDP.state_dict_type(
601
- fsdp_model,
602
- state_dict_type=StateDictType.FULL_STATE_DICT,
603
- state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
604
- optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
605
- ):
606
- # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
607
- # First the model state.
608
- model_state_dict = fsdp_model.state_dict()
609
- if get_global_rank() == 0:
610
- log.info("Saving model state...")
611
- save_state_dict(
612
- checkpoint_dir,
613
- "model.pt",
614
- model_state_dict,
615
- upload_to=upload_to,
616
- save_overwrite=self.cfg.save_overwrite,
617
- synchronize=False,
618
- )
619
- del model_state_dict
620
- barrier()
621
-
622
- # Then the optimizer state.
623
- optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
624
- if get_global_rank() == 0:
625
- log.info("Saving optim state...")
626
- save_state_dict(
627
- checkpoint_dir,
628
- "optim.pt",
629
- optim_state_dict,
630
- upload_to=upload_to,
631
- save_overwrite=self.cfg.save_overwrite,
632
- synchronize=False,
633
- )
634
- del optim_state_dict
635
- barrier()
636
-
637
- # Save trainer state.
638
- if get_global_rank() == 0:
639
- log.info("Saving trainer state...")
640
- save_state_dict(
641
- checkpoint_dir,
642
- "train.pt",
643
- trainer_state,
644
- upload_to=upload_to,
645
- save_overwrite=self.cfg.save_overwrite,
646
- synchronize=False,
647
- )
648
- # Save config.
649
- self._save_config(checkpoint_dir, upload_to=upload_to)
650
-
651
- def restore_checkpoint(
652
- self,
653
- load_path: PathOrStr,
654
- fsdp_model: FSDP,
655
- optim: Optimizer,
656
- *,
657
- local_cache: Optional[PathOrStr] = None,
658
- load_optimizer_state: bool = True,
659
- ) -> Dict[str, Any]:
660
- with FSDP.state_dict_type(
661
- fsdp_model,
662
- state_dict_type=StateDictType.FULL_STATE_DICT,
663
- state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
664
- optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
665
- ):
666
- with torch.no_grad():
667
- # fill everything with NaN, so we can check afterwards that every parameter has been restored
668
- for module_name, module in fsdp_model.named_modules():
669
- if not isinstance(module, FSDP):
670
- continue
671
- for param in module.params:
672
- param.fill_(torch.nan)
673
-
674
- # restore params from checkpoint
675
- state_dict_to_load = load_state_dict(
676
- load_path, "model.pt", local_cache=local_cache, map_location="cpu"
677
- )
678
- (
679
- state_dict_to_load,
680
- og_keys_to_new,
681
- ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
682
-
683
- for module_name, module in fsdp_model.named_modules():
684
- if not isinstance(module, FSDP):
685
- continue
686
- for param in module.params:
687
- assert param._is_flat_param
688
- for fqn, spi in zip(param._fqns, param._shard_param_infos):
689
- if not spi.in_shard:
690
- continue
691
- key = f"{module_name}.{fqn}"
692
- key = key.replace("_fsdp_wrapped_module.", "")
693
- key = key.lstrip(".")
694
- t = state_dict_to_load[key]
695
- t = t.flatten()
696
- param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
697
- t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
698
- )
699
-
700
- # make sure that every parameter has been restored
701
- for module_name, module in fsdp_model.named_modules():
702
- if not isinstance(module, FSDP):
703
- continue
704
- for param in module.params:
705
- if torch.isnan(param).any():
706
- raise ValueError(
707
- f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
708
- )
709
-
710
- # Load optimizer state.
711
- if load_optimizer_state:
712
- optim_state_dict_to_load = load_state_dict(
713
- load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
714
- )
715
- optim_state_dict_to_load = self._make_optim_state_dict_compatible(
716
- optim_state_dict_to_load,
717
- og_keys_to_new,
718
- )
719
- load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
720
- del optim_state_dict_to_load
721
-
722
- # Load other state.
723
- try:
724
- trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
725
- except FileNotFoundError:
726
- # for backwards compatibility
727
- trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
728
- barrier()
729
- return trainer_state
730
-
731
- def _make_optim_state_dict_compatible(
732
- self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
733
- ) -> Dict[str, Any]:
734
- # This state dict comes in two forms: one where the state keys are integers and one where the
735
- # keys are fully qualified parameter names. The latter case is easier to deal with here so we
736
- # first transform the integer key form into the FQN key form.
737
- if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
738
- id_to_fqn: Dict[int, str] = {}
739
- for group in optim_state_dict["param_groups"]:
740
- new_param_names = []
741
- for fqn, id in zip(group["param_names"], group["params"]):
742
- fqn = fqn.replace("_fsdp_wrapped_module.", "")
743
- id_to_fqn[id] = fqn
744
- new_param_names.append(fqn)
745
- group["param_names"] = new_param_names
746
- group["params"] = new_param_names
747
- for id in list(optim_state_dict["state"].keys()):
748
- optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
749
- else:
750
- # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
751
- for group in optim_state_dict["param_groups"]:
752
- group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
753
- group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
754
- assert group["param_names"] == group["params"]
755
- for key in list(optim_state_dict["state"].keys()):
756
- optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
757
- "state"
758
- ].pop(key)
759
-
760
- # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
761
- # First fix param names in the state.
762
- for og_key, new_keys in og_keys_to_new.items():
763
- og_state = optim_state_dict["state"].pop(og_key, None)
764
- if og_state is None:
765
- continue
766
- for i, new_key in enumerate(new_keys):
767
- if i == len(new_keys) - 1:
768
- optim_state_dict["state"][new_key] = og_state
769
- else:
770
- optim_state_dict["state"][new_key] = deepcopy(og_state)
771
- # Now fix param names in the param groups.
772
- for group in optim_state_dict["param_groups"]:
773
- og_names = group["params"]
774
- new_names = []
775
- for og_key in og_names:
776
- for new_key in og_keys_to_new[og_key]:
777
- new_names.append(new_key)
778
- group["params"] = new_names
779
- group["param_names"] = new_names
780
-
781
- return optim_state_dict
782
-
783
- def load_checkpoint(
784
- self,
785
- load_path: PathOrStr,
786
- *,
787
- local_cache: Optional[PathOrStr] = None,
788
- load_optimizer_state: bool = True,
789
- device: Optional[torch.device] = None,
790
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
791
- device = device if device is not None else torch.device("cpu")
792
- model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
793
- optim_state = None
794
- if load_optimizer_state:
795
- optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
796
- return model_state, optim_state
797
-
798
-
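``load_checkpoint`` is the single-process counterpart to ``restore_checkpoint``; a hypothetical sketch for pulling the full state dicts onto CPU, e.g. for conversion or inspection (the config, model, and path are assumed examples):

# Illustrative sketch only; `cfg` is an assumed TrainConfig, `model` an unwrapped nn.Module.
checkpointer = FullCheckpointer(cfg)
model_state, optim_state = checkpointer.load_checkpoint(
    "/checkpoints/step1000-unsharded",
    load_optimizer_state=False,   # optim_state comes back as None here
)
model.load_state_dict(model_state)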
799
- class TorchNewStyleShardedCheckpointer(Checkpointer):
800
- """
801
- A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
802
- """
803
-
804
- def save_checkpoint(
805
- self,
806
- dir: PathOrStr,
807
- fsdp_model: FSDP,
808
- optim: Optimizer,
809
- trainer_state: Dict[str, Any],
810
- *,
811
- upload_to: Optional[str] = None,
812
- ) -> None:
813
- with self._temporary_wd(dir) as checkpoint_dir:
814
- # Save model and optim state.
815
- save_fsdp_model_and_optim_state(
816
- checkpoint_dir,
817
- fsdp_model,
818
- optim,
819
- upload_to=upload_to,
820
- save_overwrite=self.cfg.save_overwrite,
821
- )
822
-
823
- # Save trainer state.
824
- log.info("Saving trainer state...")
825
- save_state_dict(
826
- checkpoint_dir,
827
- f"train/rank{get_global_rank()}.pt",
828
- trainer_state,
829
- upload_to=upload_to,
830
- save_overwrite=self.cfg.save_overwrite,
831
- )
832
-
833
- # Save config.
834
- self._save_config(checkpoint_dir, upload_to=upload_to)
835
-
836
- def restore_checkpoint(
837
- self,
838
- load_path: PathOrStr,
839
- fsdp_model: FSDP,
840
- optim: Optimizer,
841
- *,
842
- local_cache: Optional[PathOrStr] = None,
843
- load_optimizer_state: bool = True,
844
- ) -> Dict[str, Any]:
845
- # Load model and optimizer state in place.
846
- log.info("Loading model and optimizer state...")
847
- load_fsdp_model_and_optim_state(
848
- load_path,
849
- fsdp_model,
850
- optim,
851
- local_cache=local_cache,
852
- load_optimizer_state=load_optimizer_state,
853
- )
854
-
855
- # Load trainer state dict.
856
- log.info("Loading trainer state...")
857
- try:
858
- trainer_state = load_state_dict(
859
- load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
860
- )
861
- except FileNotFoundError:
862
- # Fall back to rank 0 train state.
863
- # This can happen when we're restoring a checkpoint with a different world size.
864
- trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
865
- barrier()
866
- return trainer_state
867
-
868
-
869
- class TorchLegacyShardedCheckpointer(Checkpointer):
870
- """
871
- A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
872
- and optim state.
873
-
874
- The world size must be kept consistent when using this checkpointer.
875
- """
876
-
877
- def save_checkpoint(
878
- self,
879
- dir: PathOrStr,
880
- fsdp_model: FSDP,
881
- optim: Optimizer,
882
- trainer_state: Dict[str, Any],
883
- *,
884
- upload_to: Optional[str] = None,
885
- ) -> None:
886
- with self._temporary_wd(dir) as checkpoint_dir:
887
- with FSDP.state_dict_type(
888
- fsdp_model,
889
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
890
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
891
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
892
- ):
893
- state_dict = {
894
- "model": fsdp_model.state_dict(),
895
- "optim": FSDP.optim_state_dict(fsdp_model, optim),
896
- **trainer_state,
897
- }
898
- save_state_dict(
899
- checkpoint_dir,
900
- f"rank{get_global_rank()}.pt",
901
- state_dict,
902
- upload_to=upload_to,
903
- save_overwrite=self.cfg.save_overwrite,
904
- )
905
-
906
- # Save config.
907
- self._save_config(checkpoint_dir, upload_to=upload_to)
908
-
909
- def restore_checkpoint(
910
- self,
911
- load_path: PathOrStr,
912
- fsdp_model: FSDP,
913
- optim: Optimizer,
914
- *,
915
- local_cache: Optional[PathOrStr] = None,
916
- load_optimizer_state: bool = True,
917
- ) -> Dict[str, Any]:
918
- with FSDP.state_dict_type(
919
- fsdp_model,
920
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
921
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
922
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
923
- ):
924
- # Deserialize state dict.
925
- state_dict = load_state_dict(
926
- load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
927
- )
928
-
929
- # Load model and optimizer state.
930
- log.info("Loading model state...")
931
- fsdp_model.load_state_dict(state_dict["model"])
932
- del state_dict["model"]
933
- if load_optimizer_state:
934
- log.info("Loading optimizer state...")
935
- load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
936
- del state_dict["optim"]
937
-
938
- barrier()
939
- return state_dict
940
-
941
- def unshard_checkpoint(
942
- self,
943
- load_path: PathOrStr,
944
- *,
945
- local_cache: Optional[PathOrStr] = None,
946
- load_optimizer_state: bool = True,
947
- load_trainer_state: bool = True,
948
- device: Optional[torch.device] = None,
949
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
950
- assert local_cache is None, "this method currently only supports local files"
951
- full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
952
- model_state = full_state_dict.pop("model")
953
- optim_state = full_state_dict.pop("optim")
954
- return (
955
- model_state,
956
- optim_state if load_optimizer_state else None,
957
- full_state_dict if load_trainer_state else None,
958
- )
959
-
960
- def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
961
- key = tuple() if key is None else key
962
- if isinstance(state, (list, tuple, set)):
963
- for i, sub_state in enumerate(state):
964
- self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
965
- elif isinstance(state, dict):
966
- for name in state.keys():
967
- self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
968
- elif isinstance(state, ShardedTensor):
969
- self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
970
- return
971
- else:
972
- return
973
-
974
- def _get_shard_placement_and_rank_sizes(
975
- self, shards_metadata: List[ShardMetadata], world_size: int
976
- ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
977
- def shard_size(shard_md):
978
- return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
979
-
980
- rank_sizes = [0 for _ in range(world_size)]
981
- shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
982
- for shard_md in shards_metadata:
983
- shard_rank = cast(_remote_device, shard_md.placement).rank()
984
- assert shard_rank is not None
985
- if shard_rank >= world_size:
986
- raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
987
-
988
- shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
989
- rank_sizes[shard_rank] += shard_size(shard_md)
990
-
991
- return shard_placement, rank_sizes
992
-
993
- def _copy_sharded_tensor_to_shared_mem(
994
- self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
995
- ) -> Any:
996
- shard0_md = sharded_tensor.metadata()
997
- shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
998
- shard0_md.shards_metadata, world_size
999
- )
1000
-
1001
- rank_size = rank_sizes[rank]
1002
- assert rank_size >= 0
1003
- if rank_size == 0:
1004
- return
1005
-
1006
- assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1007
- numpy_type = np.float32
1008
-
1009
- sharded_memory_name = "-".join(key + (str(rank),))
1010
-
1011
- shm = shared_memory.SharedMemory(
1012
- create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1013
- )
1014
- np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1015
-
1016
- for local_shard in sharded_tensor.local_shards():
1017
- shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1018
- assert shard_rank == rank
1019
-
1020
- src = local_shard.tensor.flatten()
1021
- shard_offset = shard_placement[local_shard.metadata][1]
1022
-
1023
- np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1024
-
1025
- shm.close()
1026
-
1027
- def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1028
- shard_number = int(shard_filepath.name[4:-3])
1029
- log.info("Starting unsharding shard number %d to shared memory", shard_number)
1030
-
1031
- with self._patch_sharded_tensor_load():
1032
- shard = torch.load(shard_filepath, map_location="cpu")
1033
- log.debug("Done loading shard number %d", shard_number)
1034
-
1035
- self._copy_sharded_tensors_to_shared_mem(
1036
- shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1037
- )
1038
- log.info("Done unsharding shard number %d to shared memory", shard_number)
1039
-
1040
- def _unshard_using_sharded_mem(
1041
- self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1042
- ) -> Any:
1043
- return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1044
-
1045
- def _unshard_state_using_shared_mem(
1046
- self, state: Any, world_size: int, device: torch.device, key: Tuple
1047
- ) -> Any:
1048
- if isinstance(state, (list, tuple, set)):
1049
- return state.__class__(
1050
- self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1051
- for i, sub_state in enumerate(state)
1052
- )
1053
- elif isinstance(state, dict):
1054
- return {
1055
- name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1056
- for name in state.keys()
1057
- }
1058
- elif isinstance(state, ShardedTensor):
1059
- return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1060
- elif isinstance(state, torch.Tensor):
1061
- return state.to(device=device)
1062
- else:
1063
- return state
1064
-
1065
- def _unshard_tensor_using_shared_mem(
1066
- self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1067
- ) -> torch.Tensor:
1068
- shard0_md = sharded_tensor.metadata()
1069
-
1070
- def shard_size(shard_md):
1071
- return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1072
-
1073
- shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1074
- shard0_md.shards_metadata, world_size
1075
- )
1076
-
1077
- assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1078
- numpy_type = np.float32
1079
-
1080
- out = torch.empty(
1081
- *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1082
- )
1083
- dims = len(sharded_tensor.metadata().size)
1084
- for shard_md, (rank, rank_offset) in shard_placement.items():
1085
- if rank >= world_size:
1086
- raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1087
-
1088
- sharded_memory_name = "-".join(key + (str(rank),))
1089
- shm = shared_memory.SharedMemory(name=sharded_memory_name)
1090
-
1091
- rank_size = rank_sizes[rank]
1092
- assert rank_size >= 0
1093
- if rank_size == 0:
1094
- continue
1095
-
1096
- np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1097
-
1098
- tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1099
- tensor = tensor.view(shard_md.shard_sizes)
1100
-
1101
- out_narrow_view = out
1102
- for dim in range(dims):
1103
- out_narrow_view = out_narrow_view.narrow(
1104
- dim,
1105
- shard_md.shard_offsets[dim],
1106
- shard_md.shard_sizes[dim],
1107
- )
1108
-
1109
- out_narrow_view.copy_(tensor)
1110
-
1111
- shm.close()
1112
- shm.unlink()
1113
-
1114
- return out
1115
-
1116
- @contextmanager
1117
- def _patch_sharded_tensor_load(self):
1118
- """
1119
- Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1120
- """
1121
-
1122
- def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1123
- ret = func(*args)
1124
- if type(ret) is not new_type:
1125
- ret = ret.as_subclass(new_type)
1126
-
1127
- # Shortcut the construction of ShardedTensor
1128
- # This is in the top 5 of my worst hacks.
1129
- if isinstance(ret, ShardedTensor):
1130
- ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1131
- return ret
1132
-
1133
- # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1134
- # Tensor does define __setstate__ even though it doesn't define
1135
- # __getstate__. So only use __setstate__ if it is NOT the one defined
1136
- # on Tensor
1137
- if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1138
- ret.__setstate__(state)
1139
- else:
1140
- ret = torch._utils._set_obj_state(ret, state)
1141
- return ret
1142
-
1143
- original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1144
- try:
1145
- torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1146
- yield
1147
- finally:
1148
- torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1149
-
1150
- def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1151
- """
1152
- The current unsharding implementation consists of:
1153
-
1154
- 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1155
- 2. Loading 1 shard on the main process as a base unsharded object.
1156
- 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1157
-
1158
- This implementation replaced a prior implementation that instead loaded
1159
- all shards using threads, because that implementation turned out to
1160
- be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1161
- The current implementation is slower than the old one in many scenarios,
1162
- but is significantly faster in the above mentioned case (e.g. 30 minutes)
1163
- if there are enough CPUs.
1164
- """
1165
-
1166
- input_dir = Path(input_dir)
1167
- skip_keys = skip_keys or set()
1168
-
1169
- shard_filepaths = list(input_dir.glob("rank*.pt"))
1170
- world_size = len(shard_filepaths)
1171
- if world_size == 0:
1172
- raise RuntimeError("No shards found for unsharding")
1173
-
1174
- log.info("Number of shards: %d", world_size)
1175
- shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1176
- min_ram_required_estimate_gb = shard_size_gb * world_size
1177
- log.info(
1178
- "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1179
- )
1180
-
1181
- log.info("Copying sharded tensors to shared memory using multiple processes")
1182
- # Copy sharded data to shared memory using multiple processes, so this process can load
1183
- # from memory rather than disk. We spawn a new process instead of forking since shared memory
1184
- # appears to get deleted when forked processes end for some reason.
1185
- executor = ProcessPoolExecutor(
1186
- mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1187
- )
1188
- futures = []
1189
- for shard_filepath in shard_filepaths:
1190
- shard_rank = int(shard_filepath.name[4:-3])
1191
-
1192
- if shard_rank >= world_size:
1193
- raise RuntimeError(
1194
- f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1195
- )
1196
-
1197
- futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1198
-
1199
- for f in as_completed(futures):
1200
- f.result()
1201
- executor.shutdown()
1202
-
1203
- log.info("Loading a shard on the main process to be unsharded state")
1204
- with self._patch_sharded_tensor_load():
1205
- state = torch.load(shard_filepaths[0], map_location="cpu")
1206
-
1207
- for key in skip_keys:
1208
- if key in state:
1209
- del state[key]
1210
-
1211
- log.info("Unsharding from %d shards ...", world_size)
1212
- return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
1213
-
1214
-
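Putting the above together, unsharding a legacy ``rank*.pt`` checkpoint can be done offline from a single Python process; a hypothetical sketch (config and paths are assumed examples):

# Illustrative sketch only; `cfg` is an assumed TrainConfig and the paths are examples.
checkpointer = TorchLegacyShardedCheckpointer(cfg)
model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
    "/checkpoints/step1000",   # local directory containing rank0.pt, rank1.pt, ...
    load_optimizer_state=True,
    load_trainer_state=True,
)
torch.save(model_state, "/checkpoints/step1000-unsharded/model.pt")  # assumed output location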
1215
- @dataclass
1216
- class _LocalShardedCheckpointerMetadata(BaseConfig):
1217
- world_size: int = field(default_factory=get_world_size)
1218
-
1219
-
1220
- @dataclass
1221
- class _FlatParamShard:
1222
- full_shape: torch.Size
1223
- shard_offsets: Tuple[int, int]
1224
- shard_data: Optional[torch.Tensor]
1225
-
1226
- def copy_into(self, full_tensor: torch.Tensor) -> None:
1227
- assert self.shard_data is not None
1228
- full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1229
- assert self.shard_data.shape == full_tensor_shard_view.shape
1230
- full_tensor_shard_view.copy_(self.shard_data)
1231
-
1232
-
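Note that ``shard_offsets`` are inclusive indices into the flattened full tensor; a tiny, hypothetical example of ``copy_into``:

# Illustrative example only: a shard covering offsets (2, 4) holds three elements.
full = torch.zeros(2, 3)
shard = _FlatParamShard(
    full_shape=full.shape,
    shard_offsets=(2, 4),
    shard_data=torch.tensor([1.0, 2.0, 3.0]),
)
shard.copy_into(full)   # full.view(-1)[2:5] now holds [1., 2., 3.]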
1233
- class LocalShardedCheckpointer(Checkpointer):
1234
- """
1235
- A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1236
- The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1237
-
1238
- The world size must be kept consistent when using this checkpointer. However, you can easily
1239
- reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1240
- using :meth:`unshard_checkpoint()` (no distributed initialization required).
1241
- """
1242
-
1243
- # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
- _FLAT_PARAM_METADATA_TO_SAVE = (
1245
- "_fqns",
1246
- "_shard_param_offsets",
1247
- "_shard_indices",
1248
- "_numels",
1249
- "_numels_with_padding",
1250
- "_shapes",
1251
- "_shard_numel_padded",
1252
- "_shard_param_infos",
1253
- )
1254
-
1255
- def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
- """
1257
- Returns a list of FSDP modules with their FQN.
1258
- """
1259
- modules = []
1260
- for name, module in fsdp_model.named_modules():
1261
- if isinstance(module, FSDP):
1262
- modules.append((name, module))
1263
- return modules
1264
-
1265
- def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
- from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
-
1268
- # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
- # an FSDP state dict through the built-in methods.
1270
- if torch.cuda.is_available():
1271
- torch.cuda.synchronize()
1272
- _lazy_init(fsdp_model, fsdp_model)
1273
-
1274
- def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
- if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
- return fsdp_model._handles # type: ignore
1277
- elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
- # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
- if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
- return [fsdp_model._handle] # type: ignore
1281
- else:
1282
- return []
1283
- else:
1284
- # Need to verify FSDP internals with newer versions.
1285
- raise NotImplementedError
1286
-
1287
- @torch.no_grad()
1288
- def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
- self._prepare_fsdp_model(fsdp_model)
1290
- module_data = []
1291
- for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
- handle_data = []
1293
- for handle in self._fsdp_handles(fsdp_module):
1294
- data: Dict[str, Any] = {}
1295
- # This is a `FlatParameter` instance.
1296
- # See `torch.distributed.fsdp.flat_param` for the API.
1297
- flat_param = handle.flat_param
1298
- data["flat_param.data"] = flat_param.detach()
1299
- for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
- if hasattr(flat_param, key):
1301
- data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
- handle_data.append(data)
1303
- module_data.append({"handles": handle_data, "name": module_fqn})
1304
- return {"modules": module_data}
1305
-
1306
- @torch.no_grad()
1307
- def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
- """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
- self._prepare_fsdp_model(fsdp_model)
1310
- fsdp_modules = self._fsdp_modules(fsdp_model)
1311
- assert len(model_state["modules"]) == len(fsdp_modules)
1312
- for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
- handles = self._fsdp_handles(fsdp_module)
1314
- assert len(handles) == len(module_data["handles"])
1315
- for handle, data in zip(handles, module_data["handles"]):
1316
- flat_param = handle.flat_param
1317
- # Make sure metadata matches.
1318
- for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
- if hasattr(flat_param, key):
1320
- assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
- # Load the flat sharded data.
1322
- flat_param.copy_(data["flat_param.data"])
1323
-
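For reference, the state produced by ``_get_flat_param_state_to_save`` and consumed by ``_load_flat_param_state`` looks roughly like the following (module names and FQNs are assumed examples):

# Illustrative shape of the saved structure; names are assumed examples.
model_state = {
    "modules": [
        {
            "name": "transformer.blocks.0",                   # FSDP module FQN
            "handles": [
                {
                    "flat_param.data": torch.empty(0),        # this rank's flat shard
                    "flat_param._fqns": ("attn_out.weight",), # per-parameter metadata
                    # ... remaining flat_param.* metadata keys from the tuple above ...
                },
            ],
        },
    ],
}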
1324
- def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
- if get_fs_local_rank() == 0:
1326
- log.info("Saving metadata...")
1327
- metadata = _LocalShardedCheckpointerMetadata()
1328
- metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
- if upload_to is not None and get_global_rank() == 0:
1330
- upload_target = f"{upload_to}/metadata.yaml"
1331
- log.info(f"Uploading {metadata_path} to {upload_target}")
1332
- upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
-
1334
- def _load_metadata(
1335
- self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
- ) -> _LocalShardedCheckpointerMetadata:
1337
- metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
- return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
-
1340
- def save_checkpoint(
1341
- self,
1342
- dir: PathOrStr,
1343
- fsdp_model: FSDP,
1344
- optim: Optimizer,
1345
- trainer_state: Dict[str, Any],
1346
- *,
1347
- upload_to: Optional[str] = None,
1348
- ) -> None:
1349
- with self._temporary_wd(dir) as checkpoint_dir:
1350
- # Gather local FSDP flat params data to save.
1351
- # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1352
- # of each original parameter so we can validate that the sharding is the same when loading
1353
- # one of these checkpoints.
1354
- log.info("Saving local FSDP flat params data...")
1355
- save_state_dict(
1356
- checkpoint_dir,
1357
- f"model/rank{get_global_rank()}.pt",
1358
- self._get_flat_param_state_to_save(fsdp_model),
1359
- upload_to=upload_to,
1360
- save_overwrite=self.cfg.save_overwrite,
1361
- )
1362
-
1363
- # Save optimizer state.
1364
- log.info("Saving local optimizer state...")
1365
- save_state_dict(
1366
- checkpoint_dir,
1367
- f"optim/rank{get_global_rank()}.pt",
1368
- optim.state_dict(),
1369
- upload_to=upload_to,
1370
- save_overwrite=self.cfg.save_overwrite,
1371
- )
1372
-
1373
- # Save trainer state.
1374
- log.info("Saving trainer state...")
1375
- save_state_dict(
1376
- checkpoint_dir,
1377
- f"train/rank{get_global_rank()}.pt",
1378
- trainer_state,
1379
- upload_to=upload_to,
1380
- save_overwrite=self.cfg.save_overwrite,
1381
- )
1382
-
1383
- # Save metadata.
1384
- self._save_metadata(checkpoint_dir, upload_to=upload_to)
1385
-
1386
- # Save config. We do this last b/c the presence of a config in a remote checkpoint
1387
- # "directory" indicates that the folder is valid, as opposed to a partially
1388
- # uploaded checkpoint directory that failed before completing.
1389
- self._save_config(checkpoint_dir, upload_to=upload_to)
1390
-
1391
- def restore_checkpoint(
1392
- self,
1393
- load_path: PathOrStr,
1394
- fsdp_model: FSDP,
1395
- optim: Optimizer,
1396
- *,
1397
- local_cache: Optional[PathOrStr] = None,
1398
- load_optimizer_state: bool = True,
1399
- ) -> Dict[str, Any]:
1400
- # Load metadata and make sure checkpoint is compatible.
1401
- metadata = self._load_metadata(load_path, local_cache=local_cache)
1402
- assert metadata.world_size == get_world_size()
1403
-
1404
- # Load local FSDP flat param data.
1405
- log.info("Loading local FSDP flat params data...")
1406
- model_state = load_state_dict(
1407
- load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1408
- )
1409
- self._load_flat_param_state(fsdp_model, model_state)
1410
- del model_state
1411
-
1412
- # Load local optim state.
1413
- if load_optimizer_state:
1414
- log.info("Loading local optimizer state...")
1415
- optim_state = load_state_dict(
1416
- load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1417
- )
1418
- # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1419
- # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1420
- # state since torch sees the state is non-empty for some params which would normally be empty,
1421
- # and then assumes it should have all of the other state tensors for that param, which it doesn't.
1422
- # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1423
- # Not the end of the world but there's probably a better way around this without resetting
1424
- # the metric.
1425
- for param_id in list(optim_state["state"].keys()):
1426
- state = optim_state["state"][param_id]
1427
- if "grad_norm_exp_avg" in state:
1428
- del state["grad_norm_exp_avg"]
1429
- if len(state) == 0:
1430
- del optim_state["state"][param_id]
1431
- optim.load_state_dict(optim_state)
1432
- del optim_state
1433
-
1434
- # Load local trainer state.
1435
- log.info("Loading local trainer state...")
1436
- trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1437
- barrier()
1438
- return trainer_state
1439
-
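Taken together, a minimal sketch of how this checkpointer might be driven from a training loop. It assumes cfg (a TrainConfig), fsdp_model, optim, and trainer_state already exist, and the checkpoint path is purely illustrative:

    checkpointer = LocalShardedCheckpointer(cfg)

    # Save: every rank writes its own model/optim/train shard under the same directory.
    checkpointer.save_checkpoint("checkpoints/step1000", fsdp_model, optim, trainer_state)

    # Restore: must run with the same world size the checkpoint was written with,
    # since restore_checkpoint asserts metadata.world_size == get_world_size().
    trainer_state = checkpointer.restore_checkpoint(
        "checkpoints/step1000", fsdp_model, optim, load_optimizer_state=True
    )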
1440
- def _iter_flat_param_shards(
1441
- self, model_state: Dict[str, Any]
1442
- ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1443
- for module_data in model_state["modules"]:
1444
- module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1445
- for handle in module_data["handles"]:
1446
- flat_data = handle["flat_param.data"]
1447
- if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1448
- # If there's padding in the flat param it should be on the right.
1449
- assert (flat_data[-num_padding:] == 0).all()
1450
- # NOTE: this changes depending on the torch version, but we don't do a version
1451
- # check since we might be trying to unshard an old checkpoint that was stored
1452
- # with a different torch version than we're currently running with.
1453
- if "flat_param._shard_indices" in handle:
1454
- # torch <=2.0.1
1455
- param_start = handle["flat_param._shard_indices"][0]
1456
- current_flat_index = 0
1457
- for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1458
- handle["flat_param._fqns"][param_start:],
1459
- handle["flat_param._shapes"][param_start:],
1460
- handle["flat_param._shard_param_offsets"],
1461
- ):
1462
- root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1463
- numel_shard = offset_end - offset_start + 1
1464
- flat_param_shard = _FlatParamShard(
1465
- full_shape=full_shape,
1466
- shard_offsets=(offset_start, offset_end),
1467
- shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1468
- )
1469
- current_flat_index += numel_shard
1470
- yield root_fqn, flat_param_shard
1471
- else:
1472
- # torch >=2.1.0
1473
- for relative_fqn, full_shape, shard_param_info in zip(
1474
- handle["flat_param._fqns"],
1475
- handle["flat_param._shapes"],
1476
- handle["flat_param._shard_param_infos"],
1477
- ):
1478
- if not shard_param_info.in_shard:
1479
- continue
1480
- root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1481
- flat_param_shard = _FlatParamShard(
1482
- full_shape=full_shape,
1483
- shard_offsets=(
1484
- shard_param_info.intra_param_start_idx,
1485
- shard_param_info.intra_param_end_idx,
1486
- ),
1487
- shard_data=flat_data[
1488
- shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1489
- + shard_param_info.numel_in_shard
1490
- ],
1491
- )
1492
- yield root_fqn, flat_param_shard
1493
-
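To make the offset convention above concrete: shard_offsets are inclusive (start, end) indices into the flattened original parameter, which is why the element count is computed as offset_end - offset_start + 1. The toy sketch below shows the idea of copying such a shard back into a materialized full parameter; the actual logic lives in _FlatParamShard.copy_into (defined elsewhere in this file), so treat this only as an illustration:

    import torch

    full_shape = (4, 3)                                 # shape of the original, unflattened parameter
    shard_offsets = (3, 8)                              # inclusive indices into the flattened parameter
    shard_data = torch.arange(6, dtype=torch.float32)   # 8 - 3 + 1 = 6 elements owned by this rank

    full_param = torch.empty(full_shape)
    full_param.fill_(torch.nan)                         # NaN-fill, as unshard_checkpoint does below
    full_param.view(-1)[shard_offsets[0] : shard_offsets[1] + 1] = shard_data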
1494
- def unshard_checkpoint(
1495
- self,
1496
- load_path: PathOrStr,
1497
- *,
1498
- local_cache: Optional[PathOrStr] = None,
1499
- load_optimizer_state: bool = True,
1500
- load_trainer_state: bool = True,
1501
- device: Optional[torch.device] = None,
1502
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1503
- device = device or torch.device("cpu")
1504
- metadata = self._load_metadata(load_path, local_cache=local_cache)
1505
-
1506
- # Gather paths to the model state files, potentially downloading them.
1507
- log.info("Gathering model state dicts...")
1508
- model_state_paths = self._gather_state_dict_paths(
1509
- load_path, "model", metadata.world_size, local_cache=local_cache
1510
- )
1511
-
1512
- # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1513
- log.info("Materializing full parameters...")
1514
- full_model_state: Dict[str, torch.Tensor] = {}
1515
- # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1516
- # the full optimizer state below without having to reload the model state dicts.
1517
- flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1518
- for rank, path in enumerate(model_state_paths):
1519
- log.info(f"Loading shards from rank {rank}...")
1520
- model_state = torch.load(path, map_location="cpu")
1521
- for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1522
- if root_fqn not in full_model_state:
1523
- log.info(
1524
- f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1525
- )
1526
- assert flat_param_shard.shard_data is not None
1527
- full_model_state[root_fqn] = torch.empty(
1528
- flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1529
- )
1530
- # Fill with NaNs so we can validate that the whole parameter has been populated
1531
- # afterwards.
1532
- full_model_state[root_fqn].fill_(torch.nan)
1533
- # Copy over the local shard to the relevant part of the full parameter.
1534
- full_param = full_model_state[root_fqn]
1535
- log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1536
- flat_param_shard.copy_into(full_param)
1537
- flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1538
-
1539
- log.info("Validating full parameters...")
1540
- for key, tensor in full_model_state.items():
1541
- if torch.isnan(tensor).any():
1542
- raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1543
-
1544
- trainer_state: Optional[Dict[str, Any]] = None
1545
- if load_trainer_state:
1546
- trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1547
-
1548
- if not load_optimizer_state:
1549
- return full_model_state, None, trainer_state
1550
-
1551
- log.info("Gathering optim state dicts...")
1552
- optim_state_paths = self._gather_state_dict_paths(
1553
- load_path, "optim", metadata.world_size, local_cache=local_cache
1554
- )
1555
-
1556
- log.info("Materializing full optim state...")
1557
- full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1558
- fqn_to_id: Dict[str, int] = {}
1559
- id_to_fqn: Dict[int, str] = {}
1560
- for rank, path in enumerate(optim_state_paths):
1561
- log.info(f"Loading sharded optim state from rank {rank}...")
1562
- optim_state = torch.load(path, map_location="cpu")
1563
-
1564
- # Initialize param groups.
1565
- # We assume parameter groups are the same across all ranks.
1566
- # The only thing that differs across ranks is the state for each local sharded param.
1567
- if "param_groups" not in full_optim_state:
1568
- full_optim_state["param_groups"] = optim_state["param_groups"]
1569
- else:
1570
- assert full_optim_state["param_groups"] == optim_state["param_groups"]
1571
-
1572
- # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1573
- if not fqn_to_id or not id_to_fqn:
1574
- for group in full_optim_state["param_groups"]:
1575
- for fqn, id in zip(group["param_names"], group["params"]):
1576
- fqn = fqn.replace("_fsdp_wrapped_module.", "")
1577
- fqn_to_id[fqn] = id
1578
- id_to_fqn[id] = fqn
1579
-
1580
- # Iterate over local shard state and copy into the full state.
1581
- for id, shard_state in optim_state["state"].items():
1582
- fqn = id_to_fqn[id]
1583
- flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1584
- full_state = full_optim_state["state"][id]
1585
- for key, shard_value in shard_state.items():
1586
- assert isinstance(shard_value, torch.Tensor)
1587
- if shard_value.shape == torch.Size([]):
1588
- # Add singleton tensors directly to full state. These should be the same across
1589
- # all ranks.
1590
- assert key in ("step", "grad_norm_exp_avg") # sanity check
1591
- if key not in full_state:
1592
- full_state[key] = shard_value.to(device)
1593
- else:
1594
- assert full_state[key] == shard_value
1595
- else:
1596
- # Otherwise we have a sharded param state.
1597
- # If the corresponding full param state hasn't been materialized yet, do so now.
1598
- assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1599
- if key not in full_state:
1600
- log.info(
1601
- f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1602
- )
1603
- full_state[key] = torch.empty(
1604
- flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1605
- )
1606
- full_state_value = full_state[key]
1607
-
1608
- # Copy over the local shard state to the relevant part of the full parameter state.
1609
- log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1610
- replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1611
-
1612
- # Lastly, clean up the parameter names in param groups.
1613
- for group in full_optim_state["param_groups"]:
1614
- group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1615
-
1616
- return full_model_state, full_optim_state, trainer_state
1617
-
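A hedged usage sketch for unshard_checkpoint, e.g. for converting a local sharded checkpoint into single-file state dicts offline; cfg is assumed to exist and the paths are illustrative:

    from pathlib import Path

    checkpointer = LocalShardedCheckpointer(cfg)
    model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
        "checkpoints/step1000", load_optimizer_state=True, load_trainer_state=True
    )

    out_dir = Path("unsharded")
    out_dir.mkdir(exist_ok=True)
    torch.save(model_state, out_dir / "model.pt")
    if optim_state is not None:
        torch.save(optim_state, out_dir / "optim.pt")
    if trainer_state is not None:
        torch.save(trainer_state, out_dir / "train.pt")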
1618
- def _get_state_dict_path(
1619
- self,
1620
- load_path: PathOrStr,
1621
- state_dict_type: str,
1622
- rank: int,
1623
- *,
1624
- local_cache: Optional[PathOrStr] = None,
1625
- progress=None,
1626
- ) -> Tuple[int, Path]:
1627
- fname = f"{state_dict_type}/rank{rank}.pt"
1628
- return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1629
-
1630
- def _gather_state_dict_paths(
1631
- self,
1632
- load_path: PathOrStr,
1633
- state_dict_type: str,
1634
- world_size: int,
1635
- *,
1636
- local_cache: Optional[PathOrStr] = None,
1637
- ) -> List[Path]:
1638
- progress = get_progress_bar()
1639
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1640
- futures = []
1641
- for rank in range(world_size):
1642
- future = executor.submit(
1643
- self._get_state_dict_path,
1644
- load_path,
1645
- state_dict_type,
1646
- rank,
1647
- local_cache=local_cache,
1648
- progress=progress,
1649
- )
1650
- futures.append(future)
1651
-
1652
- results: Dict[int, Path] = {}
1653
- for future in as_completed(futures):
1654
- rank, path = future.result()
1655
- results[rank] = path
1656
-
1657
- return [results[rank] for rank in range(world_size)]
1658
-
1659
-
1660
- class OlmoCoreCheckpointer(Checkpointer):
1661
- def save_checkpoint(
1662
- self,
1663
- dir: PathOrStr,
1664
- fsdp_model: FSDP,
1665
- optim: Optimizer,
1666
- trainer_state: Dict[str, Any],
1667
- *,
1668
- upload_to: Optional[str] = None,
1669
- ) -> None:
1670
- from olmo_core.distributed.checkpoint import ( # type: ignore
1671
- save_model_and_optim_state,
1672
- )
1673
-
1674
- with self._temporary_wd(dir) as checkpoint_dir:
1675
- log.info("Saving model and optim state...")
1676
- save_model_and_optim_state(checkpoint_dir, fsdp_model, optim, save_overwrite=self.cfg.save_overwrite)
1677
- if upload_to is not None and get_fs_local_rank() == 0:
1678
- for path in Path(checkpoint_dir).glob("**/*"):
1679
- if not path.is_file():
1680
- continue
1681
- upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
1682
- log.info(f"Uploading {path} to {upload_target}...")
1683
- upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)
1684
-
1685
- log.info("Saving trainer state...")
1686
- save_state_dict(
1687
- checkpoint_dir,
1688
- f"train/rank{get_global_rank()}.pt",
1689
- trainer_state,
1690
- save_overwrite=self.cfg.save_overwrite,
1691
- upload_to=upload_to,
1692
- )
1693
-
1694
- self._save_config(checkpoint_dir, upload_to=upload_to)
1695
-
1696
- def restore_checkpoint(
1697
- self,
1698
- load_path: PathOrStr,
1699
- fsdp_model: FSDP,
1700
- optim: Optimizer,
1701
- *,
1702
- local_cache: Optional[PathOrStr] = None,
1703
- load_optimizer_state: bool = True,
1704
- ) -> Dict[str, Any]:
1705
- from olmo_core.distributed.checkpoint import ( # type: ignore
1706
- load_model_and_optim_state,
1707
- )
1708
-
1709
- log.info("Loading model and optim state...")
1710
- load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)
1711
-
1712
- log.info("Loading trainer state...")
1713
- trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1714
-
1715
- barrier()
1716
- return trainer_state
1717
-
1718
-
1719
- def build_sharded_checkpointer(
1720
- cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
1721
- ) -> Checkpointer:
1722
- name = name or cfg.sharded_checkpointer
1723
- if name == ShardedCheckpointerType.torch_new:
1724
- return TorchNewStyleShardedCheckpointer(cfg)
1725
- elif name == ShardedCheckpointerType.torch_legacy:
1726
- return TorchLegacyShardedCheckpointer(cfg)
1727
- elif name == ShardedCheckpointerType.local:
1728
- return LocalShardedCheckpointer(cfg)
1729
- elif name == ShardedCheckpointerType.olmo_core:
1730
- return OlmoCoreCheckpointer(cfg)
1731
- else:
1732
- raise NotImplementedError(name)
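Finally, a small sketch of how a caller might select a checkpointer; cfg is assumed to be an already-constructed TrainConfig, and the explicit override mirrors the name keyword above:

    checkpointer = build_sharded_checkpointer(cfg)  # uses cfg.sharded_checkpointer

    # Or force a specific implementation regardless of the config:
    local_checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)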