Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| """ | |
| Training script using the new "LazyConfig" python config files. | |
| This script reads a given python config file and runs the training or evaluation. | |
| It can be used to train any models or dataset as long as they can be | |
| instantiated by the recursive construction defined in the given config file. | |
| Besides lazy construction of models, dataloader, etc., this script expects a | |
| few common configuration parameters currently defined in "configs/common/train.py". | |
| To add more complicated training logic, you can easily add other configs | |
| in the config file and implement a new train_net.py to handle them. | |
| """ | |
| import logging | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.config import LazyConfig, instantiate | |
| from detectron2.engine import ( | |
| AMPTrainer, | |
| SimpleTrainer, | |
| default_argument_parser, | |
| default_setup, | |
| default_writers, | |
| hooks, | |
| launch, | |
| ) | |
| from detectron2.engine.defaults import create_ddp_model | |
| from detectron2.evaluation import inference_on_dataset, print_csv_format | |
| from detectron2.utils import comm | |
| logger = logging.getLogger("detectron2") | |
def do_test(cfg, model):
    """
    Evaluate ``model`` on the test dataloader described by ``cfg``.

    Args:
        cfg: config object. ``cfg.dataloader`` is checked for an ``"evaluator"``
            entry; when present, ``cfg.dataloader.test`` and
            ``cfg.dataloader.evaluator`` are instantiated.
        model: the model passed to ``inference_on_dataset``.

    Returns:
        The result of ``inference_on_dataset`` (also printed through
        ``print_csv_format``), or ``None`` when no evaluator is configured.
    """
    if "evaluator" not in cfg.dataloader:
        # No evaluator configured: nothing is instantiated or run.
        return None

    test_loader = instantiate(cfg.dataloader.test)
    evaluator = instantiate(cfg.dataloader.evaluator)
    results = inference_on_dataset(model, test_loader, evaluator)
    print_csv_format(results)
    return results
def do_train(args, cfg):
    """
    Build model, optimizer, data loader and trainer from ``cfg``, then train.

    Args:
        args: parsed command-line arguments; only ``args.resume`` is read here.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    net = instantiate(cfg.model)
    log = logging.getLogger("detectron2")
    log.info("Model:\n{}".format(net))
    net.to(cfg.train.device)

    # Inject the live model into the optimizer config before instantiating it.
    cfg.optimizer.params.model = net
    optimizer = instantiate(cfg.optimizer)

    data_loader = instantiate(cfg.dataloader.train)

    # From here on the (possibly wrapped) model returned by create_ddp_model is
    # what the trainer, the checkpointer and the evaluation hook all use.
    ddp_model = create_ddp_model(net, **cfg.train.ddp)
    if cfg.train.amp.enabled:
        trainer_cls = AMPTrainer
    else:
        trainer_cls = SimpleTrainer
    trainer = trainer_cls(ddp_model, data_loader, optimizer)

    checkpointer = DetectionCheckpointer(
        ddp_model, cfg.train.output_dir, optimizer=optimizer, trainer=trainer
    )

    # Hooks are constructed in registration order. The checkpointing and
    # writing hooks are only built on the main process; other processes
    # pass None in those slots.
    timer_hook = hooks.IterationTimer()
    lr_hook = hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier))
    if comm.is_main_process():
        ckpt_hook = hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
    else:
        ckpt_hook = None
    eval_hook = hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, ddp_model))
    if comm.is_main_process():
        writer_hook = hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        )
    else:
        writer_hook = None
    trainer.register_hooks([timer_hook, lr_hook, ckpt_hook, eval_hook, writer_hook])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    # The checkpoint stores the training iteration that just finished, so a
    # resumed run starts at the following iteration; otherwise start from 0.
    resumed = args.resume and checkpointer.has_checkpoint()
    start_iter = trainer.iter + 1 if resumed else 0
    trainer.train(start_iter, cfg.train.max_iter)
def main(args):
    """
    Load the lazy config, apply overrides, then either train or evaluate.

    Args:
        args: parsed command-line arguments. This function reads
            ``args.config_file``, ``args.opts`` and ``args.eval_only``, and
            forwards ``args`` to ``default_setup`` and ``do_train``.
    """
    cfg = LazyConfig.apply_overrides(LazyConfig.load(args.config_file), args.opts)
    default_setup(cfg, args)

    if not args.eval_only:
        do_train(args, cfg)
        return

    # Evaluation only: build the model, load the initial checkpoint into it,
    # and print whatever do_test returns.
    eval_model = instantiate(cfg.model)
    eval_model.to(cfg.train.device)
    eval_model = create_ddp_model(eval_model)
    DetectionCheckpointer(eval_model).load(cfg.train.init_checkpoint)
    print(do_test(cfg, eval_model))
if __name__ == "__main__":
    # Script entry point: parse the command line with detectron2's default
    # argument parser, then hand `main` to `launch` together with the
    # GPU / machine settings taken from those arguments.
    args = default_argument_parser().parse_args()
    launch(
        main, args.num_gpus,
        num_machines=args.num_machines, machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),  # positional arguments forwarded to main
    )