|
from pathlib import Path |
|
import wandb |
|
|
|
|
|
def is_run_id(run_id: str) -> bool:
    """Return whether *run_id* has the shape of a run ID.

    A string qualifies when it is exactly eight characters long and every
    character is alphanumeric, as judged by ``str.isalnum``.
    """
    if len(run_id) != 8:
        return False
    return run_id.isalnum()
|
|
|
|
|
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version tag.

    The tag is read from ``artifact.version`` and is expected to be a
    one-character prefix followed by digits, e.g. ``"v12"`` -> ``12``.
    The first character is dropped unconditionally and the rest is parsed
    with ``int``, so a malformed tag raises ``ValueError``.
    """
    version_tag: str = artifact.version
    digits = version_tag[1:]
    return int(digits)
|
|
|
|
|
def download_latest_checkpoint(run_path: str, download_dir: Path) -> Path:
    """Download the highest-versioned committed model artifact of a run.

    Args:
        run_path: Run path passed to ``wandb.Api().run``; presumably of the
            form ``"entity/project/run_id"`` (confirm against callers).  It is
            also used as the sub-directory name under *download_dir*.
        download_dir: Base directory.  It is created if missing, and the
            artifact is downloaded into ``download_dir / run_path``.

    Returns:
        ``download_dir / run_path / "model.ckpt"``.  The file name is fixed
        here and is not checked against the artifact's actual contents.

    Raises:
        ValueError: If the run has no artifact with type ``"model"`` and
            state ``"COMMITTED"``.
    """
    api = wandb.Api()
    run = api.run(run_path)

    # Only fully committed model artifacts are candidates; pick the one with
    # the highest numeric version (first one wins on ties, as before).
    candidates = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )
    latest = max(candidates, key=version_to_int, default=None)
    if latest is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            f"No committed model artifact found for run {run_path!r}"
        )

    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_path
    latest.download(root=root)
    return root / "model.ckpt"
|
|