from pathlib import Path
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union
from typing_extensions import override
from functools import wraps
import os
from wandb_osh.hooks import TriggerWandbSyncHook
import time
from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH
if TYPE_CHECKING:
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run
class SpaceEfficientWandbLogger(WandbLogger):
    """
    A wandb logger that by default overrides artifacts to save space, instead of creating new version.
    A variable expiration_days can be set to control how long older versions of artifacts are kept.
    By default, the latest version is kept indefinitely, while older versions are kept for 5 days.
    Set ``expiration_days=None`` to keep all versions indefinitely.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        expiration_days: Optional[int] = 5,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): the original code invoked super().__init__ twice — first
        # with offline=False, then again with the requested `offline` value. The
        # first call was fully overwritten by the second, so it has been removed.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=offline,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        # Days before superseded artifact versions expire; None disables expiry.
        self.expiration_days = expiration_days
        # Artifacts logged on the previous scan. Their TTL is set only once a
        # newer version has been logged, so the latest version never expires.
        self._last_artifacts = []

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Log new checkpoints as wandb artifacts and set a TTL on the
        previously logged versions so that older versions are reclaimed."""
        import wandb

        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
        # log iteratively all new checkpoints
        artifacts = []
        for t, p, s, tag in checkpoints:
            metadata = {
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                checkpoint_callback.__class__.__name__: {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                    ]
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
            self._logged_model_time[p] = t
            artifacts.append(artifact)
        # Expire the versions logged on the previous scan, now superseded.
        for artifact in self._last_artifacts:
            if not self._offline:
                # wait() requires server contact; skip it for offline runs.
                artifact.wait()
            # Guard: expiration_days is Optional — timedelta(days=None) raises.
            if self.expiration_days is not None:
                artifact.ttl = timedelta(days=self.expiration_days)
                artifact.save()
        self._last_artifacts = artifacts
class OfflineWandbLogger(SpaceEfficientWandbLogger):
    """
    Wraps WandbLogger to trigger offline sync hook occasionally.
    This is useful when running on slurm clusters, many of which
    only have internet on login nodes, not compute nodes.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        min_sync_interval: float = 60,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            min_sync_interval: minimum number of seconds between two offline
                sync triggers (previously hard-coded to 60).
            Remaining arguments are forwarded to ``SpaceEfficientWandbLogger``.
        """
        # NOTE(review): the parent is deliberately initialized with
        # offline=False while the requested `offline` flag is kept in
        # self._offline below — presumably so artifact logging still works
        # while actual syncing is deferred to wandb-osh. Confirm before
        # changing.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=False,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        self._offline = offline
        # Directory watched by the wandb-osh daemon on the login node.
        communication_dir = Path(".wandb_osh_command_dir")
        communication_dir.mkdir(parents=True, exist_ok=True)
        self.trigger_sync = TriggerWandbSyncHook(communication_dir)
        # Throttle state: last trigger time and minimum spacing (seconds).
        self.last_sync_time = 0.0
        self.min_sync_interval = min_sync_interval
        # Path handed to the sync hook; points at the current run's files.
        self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run")

    @override
    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log metrics as usual, then request an offline sync at most once
        every ``min_sync_interval`` seconds."""
        out = super().log_metrics(metrics, step)
        if time.time() - self.last_sync_time > self.min_sync_interval:
            self.trigger_sync(self.wandb_dir)
            self.last_sync_time = time.time()
        return out