|
from pathlib import Path |
|
from datetime import timedelta |
|
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union |
|
from typing_extensions import override |
|
from functools import wraps |
|
import os |
|
from wandb_osh.hooks import TriggerWandbSyncHook |
|
import time |
|
from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor |
|
from lightning.pytorch.utilities.rank_zero import rank_zero_only |
|
from lightning.fabric.utilities.types import _PATH |
|
|
|
|
|
if TYPE_CHECKING: |
|
from wandb.sdk.lib import RunDisabled |
|
from wandb.wandb_run import Run |
|
|
|
|
|
class SpaceEfficientWandbLogger(WandbLogger):
    """
    A wandb logger that by default overrides artifacts to save space, instead of creating new version.
    A variable expiration_days can be set to control how long older versions of artifacts are kept.
    By default, the latest version is kept indefinitely, while older versions are kept for 5 days.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        expiration_days: Optional[int] = 5,
        **kwargs: Any,
    ) -> None:
        # BUGFIX(review): the original called super().__init__ twice in a row —
        # once with a hard-coded offline=False and again with offline=offline.
        # The duplicate first call has been removed; the surviving call honors
        # the caller's `offline` argument.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=offline,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        # Days before superseded artifact versions are garbage-collected by
        # wandb; None disables expiration entirely.
        self.expiration_days = expiration_days
        # Artifacts logged on the previous scan; they receive a TTL on the
        # next scan, once they have been superseded by newer versions.
        self._last_artifacts = []

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload new/updated checkpoints as wandb artifacts and expire the
        artifact versions logged on the previous scan."""
        import wandb

        # (timestamp, path, score, tag) tuples for checkpoints not yet logged.
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        artifacts = []
        for t, p, s, tag in checkpoints:
            metadata = {
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                checkpoint_callback.__class__.__name__: {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                    ]
                    if hasattr(checkpoint_callback, k)
                },
            }
            # Lazily derive an artifact name from the run id on first use.
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"

            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # Remember the upload time so the same checkpoint is not re-logged.
            self._logged_model_time[p] = t
            artifacts.append(artifact)

        # Mark the previously logged versions for deletion after
        # `expiration_days`. Skipped when offline (artifact.wait() would
        # block on a server round-trip) or when expiration is disabled.
        if self.expiration_days is not None:
            for artifact in self._last_artifacts:
                if not self._offline:
                    artifact.wait()
                    artifact.ttl = timedelta(days=self.expiration_days)
                    artifact.save()
        self._last_artifacts = artifacts
|
|
|
|
|
class OfflineWandbLogger(SpaceEfficientWandbLogger):
    """
    Wraps WandbLogger to trigger offline sync hook occasionally.
    This is useful when running on slurm clusters, many of which
    only has internet on login nodes, not compute nodes.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        min_sync_interval: float = 60,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): the parent is deliberately constructed with
        # offline=False and the requested flag is stored on self._offline
        # afterwards — presumably so the parent's artifact handling sees the
        # real mode while wandb-osh performs the actual syncing; confirm
        # against the wandb-osh workflow before changing.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=False,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        self._offline = offline
        # Directory watched by the wandb-osh daemon on the login node.
        communication_dir = Path(".wandb_osh_command_dir")
        communication_dir.mkdir(parents=True, exist_ok=True)
        self.trigger_sync = TriggerWandbSyncHook(communication_dir)
        self.last_sync_time = 0.0
        # Minimum seconds between sync triggers, to avoid flooding the
        # sync daemon with requests (was hard-coded to 60; now a parameter).
        self.min_sync_interval = min_sync_interval
        self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run")

    @override
    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log metrics via the parent, then rate-limitedly ping the sync hook."""
        out = super().log_metrics(metrics, step)
        if time.time() - self.last_sync_time > self.min_sync_interval:
            self.trigger_sync(self.wandb_dir)
            self.last_sync_time = time.time()
        return out
|
|