|
from typing import Optional, Union |
|
from omegaconf import DictConfig |
|
import pathlib |
|
from lightning.pytorch.loggers.wandb import WandbLogger |
|
|
|
from .exp_base import BaseExperiment |
|
from .exp_video import VideoPredictionExperiment |
|
from .exp_pose import PoseExperiment |
|
|
|
|
|
# Maps an experiment name (the key read from ``cfg.experiment._name`` in
# ``build_experiment``) to the experiment class that gets instantiated.
exp_registry = {
    "exp_video": VideoPredictionExperiment,
    "exp_pose": PoseExperiment,
}
|
|
|
|
|
def build_experiment(
    cfg: DictConfig,
    logger: Optional[WandbLogger] = None,
    ckpt_path: Optional[Union[str, pathlib.Path]] = None,
) -> BaseExperiment:
    """
    Instantiate the experiment class registered under ``cfg.experiment._name``.

    :param cfg: configuration; ``cfg.experiment._name`` selects which entry of
        ``exp_registry`` is constructed
    :param logger: optional logger, forwarded unchanged to the experiment constructor
    :param ckpt_path: optional checkpoint path, forwarded unchanged to the
        experiment constructor
    :raises ValueError: if ``cfg.experiment._name`` is not a key of ``exp_registry``
    :return: the experiment instance created as ``cls(cfg, logger, ckpt_path)``
    """
    experiment_name = cfg.experiment._name

    # Fail early with a message listing every registered name, so a typo in
    # the config is easy to spot.
    if experiment_name not in exp_registry:
        registered = list(exp_registry.keys())
        raise ValueError(
            f"Experiment {experiment_name} not found in registry {registered}. "
            "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
        )

    experiment_cls = exp_registry[experiment_name]
    # Arguments are passed positionally, matching the original call.
    return experiment_cls(cfg, logger, ckpt_path)
|
|