File size: 933 Bytes
27ca8b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from pathlib import Path
import wandb
def is_run_id(run_id: str) -> bool:
    """Heuristically decide whether *run_id* looks like a run ID.

    A candidate qualifies when it is exactly 8 characters long and every
    character is alphanumeric (per ``str.isalnum``).
    """
    if len(run_id) != 8:
        return False
    return run_id.isalnum()
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version string.

    Versions are expected to look like ``"v<N>"``: the leading character is
    dropped and the remainder is parsed as an integer (``"v12"`` -> ``12``).
    """
    numeric_part = artifact.version[1:]
    return int(numeric_part)
def download_latest_checkpoint(
    run_path: str, download_dir: Path, *, checkpoint_name: str = "model.ckpt"
) -> Path:
    """Download the highest-versioned committed model artifact of a W&B run.

    Args:
        run_path: Run path handed to ``wandb.Api().run`` (for example
            ``"entity/project/run_id"``).
        download_dir: Base directory. The artifact is downloaded into
            ``download_dir / run_path``; ``download_dir`` is created if missing.
        checkpoint_name: File name of the checkpoint inside the downloaded
            artifact. Defaults to ``"model.ckpt"``.

    Returns:
        ``download_dir / run_path / checkpoint_name``. Whether that file
        actually exists in the artifact is not checked here.

    Raises:
        ValueError: If the run has no artifact of type ``"model"`` in the
            ``"COMMITTED"`` state.
    """
    api = wandb.Api()
    run = api.run(run_path)

    # Consider only model artifacts in the COMMITTED state and pick the one
    # with the highest "vN" version number (first one wins on ties).
    candidates = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )
    latest = max(candidates, key=version_to_int, default=None)
    if latest is None:
        # Fail with a clear message (and before touching the filesystem)
        # instead of an AttributeError on ``None.download``.
        raise ValueError(f"No committed model artifact found for run {run_path!r}")

    # Download the checkpoint into a per-run subdirectory.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_path
    latest.download(root=root)
    return root / checkpoint_name
|