"""Training entry point: build a model from a detectron2 LazyConfig and train it with CycleTrainer."""
import logging
import os.path as osp
import sys
import warnings

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.utils import comm

# Make the directory above this script's directory importable so that the
# project-local ``engine`` module can be found. ``abspath`` is required
# because ``__file__`` may be a relative path (e.g. scripts run on
# Python < 3.9), in which case ``dirname(dirname(...))`` collapses to ""
# and the current working directory would be added instead.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))

# Silences every Python warning for the whole process.
warnings.filterwarnings("ignore")

logger = logging.getLogger("detectron2")

from engine import CycleTrainer  # noqa: E402  (needs the sys.path entry above)


def do_train(args, cfg):
    """Instantiate model, optimizer and data loader from ``cfg`` and run training.

    Args:
        args: parsed command-line arguments; only ``args.resume`` is read here.
        cfg: a lazy config object. The following attributes are read:
            model: instantiated to a module
            optimizer: instantiated to an optimizer (its ``params.model`` is
                set to the instantiated model first)
            dataloader.train: instantiated to the training data loader
            lr_multiplier: instantiated and passed to ``hooks.LRScheduler``
            train: misc training options, of which this function uses
                device, ddp (dict), output_dir (str), checkpointer (dict),
                max_iter (int), log_period (int) and init_checkpoint (str)
    """
    model = instantiate(cfg.model)
    logger.info("Model:\n%s", model)
    model.to(cfg.train.device)

    # The optimizer config needs the live model before it can be instantiated.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = CycleTrainer(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )

    # Checkpointing and metric writing are registered on the main process only;
    # other ranks pass None in those slots.
    is_main = comm.is_main_process()
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if is_main
            else None,
            # Evaluation hook is disabled (no ``do_test`` is visible in this script):
            # hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if is_main
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # Continue one past ``trainer.iter`` — presumably restored from the
        # checkpoint since the checkpointer was given ``trainer=trainer``.
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)


def main(args):
    """Load the config file, apply command-line overrides, and start training.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts`` are
            read here, and the whole object is forwarded to ``default_setup``
            and ``do_train``.
    """
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )