File size: 1,234 Bytes
27ca8b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from typing import Optional, Union
from omegaconf import DictConfig
import pathlib
from lightning.pytorch.loggers.wandb import WandbLogger
from .exp_base import BaseExperiment
from .exp_video import VideoPredictionExperiment
from .exp_pose import PoseExperiment
# Registry of available experiments, keyed by name.
# Each key must match the name of a yaml file under
# '[project_root]/configurations/experiment' (without the .yaml suffix);
# build_experiment looks up the class to instantiate here via
# ``cfg.experiment._name``.
exp_registry = {
    "exp_video": VideoPredictionExperiment,
    "exp_pose": PoseExperiment,
}
def build_experiment(
    cfg: DictConfig,
    logger: Optional[WandbLogger] = None,
    ckpt_path: Optional[Union[str, pathlib.Path]] = None,
) -> BaseExperiment:
    """
    Instantiate the experiment class registered under ``cfg.experiment._name``.

    :param cfg: configuration; ``cfg.experiment._name`` selects the class from ``exp_registry``
    :param logger: optional WandbLogger, passed through unchanged to the experiment constructor
    :param ckpt_path: optional checkpoint path (str or pathlib.Path), passed through unchanged
        to the experiment constructor
    :return: the result of calling the registered class as ``cls(cfg, logger, ckpt_path)``
    :raises ValueError: if ``cfg.experiment._name`` is not a key of ``exp_registry``
    """
    experiment_name = cfg.experiment._name

    # Guard clause: fail early with a message that lists every registered name.
    if experiment_name not in exp_registry:
        registered_names = list(exp_registry.keys())
        raise ValueError(
            f"Experiment {experiment_name} not found in registry {registered_names}. "
            "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
        )

    experiment_cls = exp_registry[experiment_name]
    return experiment_cls(cfg, logger, ckpt_path)
|