worldmem / experiments /__init__.py
xizaoqu
init
27ca8b3
raw
history blame
1.23 kB
from typing import Optional, Union
from omegaconf import DictConfig
import pathlib
from lightning.pytorch.loggers.wandb import WandbLogger
from .exp_base import BaseExperiment
from .exp_video import VideoPredictionExperiment
from .exp_pose import PoseExperiment
# Registry mapping experiment names to experiment classes.
# Every key must match a yaml file (minus the .yaml suffix) located under
# '[project_root]/configurations/experiment'.
exp_registry = {
    "exp_video": VideoPredictionExperiment,
    "exp_pose": PoseExperiment,
}
def build_experiment(
    cfg: DictConfig,
    logger: Optional[WandbLogger] = None,
    ckpt_path: Optional[Union[str, pathlib.Path]] = None,
) -> BaseExperiment:
    """
    Instantiate the experiment class registered under ``cfg.experiment._name``.

    :param cfg: configuration object; ``cfg.experiment._name`` selects the registry entry
    :param logger: optional logger, forwarded unchanged to the experiment class
    :param ckpt_path: optional checkpoint path, forwarded unchanged to the experiment class
    :return: the object produced by calling the registered class with
        ``(cfg, logger, ckpt_path)``
    :raises ValueError: if ``cfg.experiment._name`` is not a key of ``exp_registry``
    """
    exp_name = cfg.experiment._name

    # Happy path first: a known name maps straight to its experiment class.
    if exp_name in exp_registry:
        experiment_cls = exp_registry[exp_name]
        return experiment_cls(cfg, logger, ckpt_path)

    # Unknown name: report it together with the names that are registered.
    message = (
        f"Experiment {exp_name} not found in registry {list(exp_registry.keys())}. "
        "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
    )
    raise ValueError(message)