# Source: worldmem/utils/ckpt_utils.py (commit 27ca8b3 "init" by xizaoqu, 933 bytes)
from pathlib import Path
import wandb
def is_run_id(run_id: str) -> bool:
    """Heuristically decide whether ``run_id`` is a run ID.

    A string qualifies when it is exactly eight characters long and every
    character is alphanumeric according to ``str.isalnum``.
    """
    if len(run_id) != 8:
        return False
    return run_id.isalnum()
def version_to_int(artifact) -> int:
    """Return the integer part of an artifact version label, e.g. "v12" -> 12.

    The first character of ``artifact.version`` (the "v") is dropped and the
    remainder is parsed with ``int``.
    """
    version_label = artifact.version
    numeric_part = version_label[1:]  # strip the leading "v"
    return int(numeric_part)
def download_latest_checkpoint(run_path: str, download_dir: Path) -> Path:
    """Download the highest-versioned committed model artifact of a W&B run.

    Args:
        run_path: Run path passed to ``wandb.Api().run`` (presumably of the
            form "entity/project/run_id" -- confirm with callers). It is also
            used as the relative sub-directory under ``download_dir``.
        download_dir: Directory under which the artifact is downloaded. It is
            created if missing. A plain ``str`` is accepted and converted to
            ``Path``.

    Returns:
        ``download_dir / run_path / "model.ckpt"``.

    Raises:
        ValueError: If the run has no artifact of type "model" in the
            "COMMITTED" state.
    """
    api = wandb.Api()
    run = api.run(run_path)

    # Keep only model artifacts whose state is "COMMITTED".
    checkpoints = [
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    ]
    if not checkpoints:
        # Without this guard the failure surfaced as an opaque
        # AttributeError ('NoneType' object has no attribute 'download').
        raise ValueError(
            f"No committed model checkpoint found for run {run_path!r}."
        )

    # Versions look like "vN"; take the largest N. On ties max() keeps the
    # first candidate, matching the original strict ">" comparison.
    latest = max(checkpoints, key=version_to_int)

    download_dir = Path(download_dir)
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_path
    latest.download(root=root)
    # NOTE(review): assumes the artifact contains a file named "model.ckpt";
    # confirm against how checkpoints are logged to W&B.
    return root / "model.ckpt"