# ------------------------------------------------------------------------
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# by Feng Li and Hao Zhang.
# ------------------------------------------------------------------------
"""
OpenSeeD training script, based on MaskDINO.
"""
try:
from shapely.errors import ShapelyDeprecationWarning
import warnings
warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
except Exception:
pass
import sys
import copy
import itertools
import logging
import os
import time
from collections import OrderedDict
from typing import Any, Dict, List, Set
from fvcore.nn.precise_bn import get_bn_modules
import torch
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg, CfgNode
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
from detectron2.config import LazyConfig, instantiate
from utils.arguments import load_opt_command
from detectron2.utils.comm import get_world_size, is_main_process
# MaskDINO
from datasets import (
build_train_dataloader,
build_evaluator,
build_eval_dataloader,
)
import random
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
hooks,
launch,
create_ddp_model,
AMPTrainer,
SimpleTrainer
)
import weakref
from openseed import build_model
from openseed.BaseModel import BaseModel
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class Trainer(DefaultTrainer):
"""
    Extension of detectron2's DefaultTrainer adapted to OpenSeeD (built on MaskDINO).
"""
def __init__(self, cfg):
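        # NOTE: super(DefaultTrainer, self).__init__() below resolves to
        # TrainerBase.__init__, deliberately skipping DefaultTrainer.__init__ so that
        # the model, optimizer and dataloader can be built explicitly here.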
super(DefaultTrainer, self).__init__()
logger = logging.getLogger("detectron2")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
model, data_loader, optimizer
)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# add model EMA
kwargs = {
'trainer': weakref.proxy(self),
}
# kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model)) TODO: release ema training for large models
self.checkpointer = DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg['OUTPUT_DIR'],
**kwargs,
)
self.start_iter = 0
self.max_iter = cfg['SOLVER']['MAX_ITER']
self.cfg = cfg
self.register_hooks(self.build_hooks())
# TODO: release model conversion checkpointer from DINO to MaskDINO
# TODO: release GPU cluster submit scripts based on submitit for multi-node training
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = copy.deepcopy(self.cfg)
# cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(),
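            # placeholder where DefaultTrainer would put a PreciseBN hook; PreciseBN is
            # not used here, and register_hooks() silently drops None entries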
None,
]
        # Do PreciseBN before checkpointer, because it updates the model and needs to
        # be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# Here the default print/log frequency of each writer is used.
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
return ret
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
        It builds the OpenSeeD model via :func:`openseed.build_model` and wraps it in
        :class:`BaseModel`. Overwrite it if you'd like a different model.
"""
model = BaseModel(cfg, build_model(cfg)).cuda()
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
return build_evaluator(cfg, dataset_name, output_folder=output_folder)
@classmethod
def build_train_loader(cls, cfg):
        return build_train_dataloader(cfg)
@classmethod
def build_test_loader(cls, cfg, dataset_name):
        loader = build_eval_dataloader(cfg)
return loader
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)
@classmethod
def build_optimizer(cls, cfg, model):
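        """
        Build an SGD or AdamW optimizer with per-parameter hyperparameters:
        SOLVER.LR_MULTIPLIER rescales the learning rate of matching parameters, and
        weight decay is overridden separately for position embeddings, norm layers,
        embeddings and biases. Full-model gradient clipping is added when
        SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model".
        """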
cfg_solver = cfg['SOLVER']
weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)
defaults = {}
defaults["lr"] = cfg_solver['BASE_LR']
defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']
norm_module_types = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch.nn.GroupNorm,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.LocalResponseNorm,
)
lr_multiplier = cfg['SOLVER']['LR_MULTIPLIER']
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
hyperparams = copy.copy(defaults)
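                # SOLVER.LR_MULTIPLIER maps name fragments to multipliers; a multiplier
                # applies when its key is a substring of "<module_name>.<param_name>"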
for key, lr_mul in lr_multiplier.items():
if key in "{}.{}".format(module_name, module_param_name):
hyperparams["lr"] = hyperparams["lr"] * lr_mul
if is_main_process():
logger.info("Modify Learning rate of {}: {}".format(
"{}.{}".format(module_name, module_param_name), lr_mul))
if (
"relative_position_bias_table" in module_param_name
or "absolute_pos_embed" in module_param_name
):
hyperparams["weight_decay"] = 0.0
if isinstance(module, norm_module_types):
hyperparams["weight_decay"] = weight_decay_norm
if isinstance(module, torch.nn.Embedding):
hyperparams["weight_decay"] = weight_decay_embed
if "bias" in module_name:
hyperparams["weight_decay"] = weight_decay_bias
params.append({"params": [value], **hyperparams})
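        # Optionally wrap the optimizer class so that, before every step, gradients are
        # clipped by their global norm over *all* parameter groups at once
        # (SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model").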
def maybe_add_full_model_gradient_clipping(optim):
# detectron2 doesn't have full model gradient clipping now
clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
enable = (
cfg_solver['CLIP_GRADIENTS']['ENABLED']
and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
and clip_norm_val > 0.0
)
class FullModelGradientClippingOptimizer(optim):
def step(self, closure=None):
all_params = itertools.chain(*[x["params"] for x in self.param_groups])
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
super().step(closure=closure)
return FullModelGradientClippingOptimizer if enable else optim
optimizer_type = cfg_solver['OPTIMIZER']
if optimizer_type == "SGD":
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
)
elif optimizer_type == "ADAMW":
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
params, cfg_solver['BASE_LR']
)
else:
raise NotImplementedError(f"no optimizer type {optimizer_type}")
return optimizer
@staticmethod
def auto_scale_workers(cfg, num_workers: int):
"""
Returns:
CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
"""
old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
if old_world_size == 0 or old_world_size == num_workers:
return cfg
cfg = copy.deepcopy(cfg)
# frozen = cfg.is_frozen()
# cfg.defrost()
assert (
cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
), "Invalid REFERENCE_WORLD_SIZE in config!"
scale = num_workers / old_world_size
bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
logger = logging.getLogger(__name__)
logger.info(
f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
f"max_iter={max_iter}, warmup={warmup_iter}."
)
return cfg
@classmethod
def test(cls, cfg, model, evaluators=None):
from utils.misc import hook_metadata, hook_switcher, hook_opt
from openseed.utils import get_class_names
from detectron2.utils.logger import log_every_n_seconds
import datetime
        # build the evaluation dataloaders
dataloaders = cls.build_test_loader(cfg, dataset_name=None)
dataset_names = cfg['DATASETS']['TEST']
model = model.eval().cuda()
model_without_ddp = model
        if not isinstance(model, BaseModel):
            model_without_ddp = model.module
        results = OrderedDict()
        for dataloader, dataset_name in zip(dataloaders, dataset_names):
# build evaluator
evaluator = build_evaluator(cfg, dataset_name, cfg['OUTPUT_DIR'])
evaluator.reset()
with torch.no_grad():
# setup model
names = get_class_names(dataset_name, cfg['MODEL'].get('BACKGROUND', True))
# names = get_class_names(dataset_name)
model_without_ddp.model.metadata = MetadataCatalog.get(dataset_name)
eval_type = model_without_ddp.model.metadata.evaluator_type
if 'background' in names:
model_without_ddp.model.sem_seg_head.num_classes = len(names) - 1
else:
model_without_ddp.model.sem_seg_head.num_classes = len(names)
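                # recompute the class-name (text) embeddings for this dataset so the
                # open-vocabulary head scores exactly the evaluation label set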
model_without_ddp.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(names, is_eval=True)
hook_switcher(model_without_ddp, dataset_name)
# hook_opt(model, dataset_name)
# setup task
task = 'seg'
# setup timer
total = len(dataloader)
num_warmup = min(5, total - 1)
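            # the timers below are reset after the first few warm-up iterations, which
            # include CUDA initialization and dataloader start-up costs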
start_time = time.perf_counter()
total_data_time = 0
total_compute_time = 0
total_eval_time = 0
start_data_time = time.perf_counter()
for idx, batch in enumerate(dataloader):
total_data_time += time.perf_counter() - start_data_time
if idx == num_warmup:
start_time = time.perf_counter()
total_data_time = 0
total_compute_time = 0
total_eval_time = 0
start_compute_time = time.perf_counter()
# forward
with torch.autocast(device_type='cuda', dtype=torch.float16):
outputs = model(batch, inference_task=task)
total_compute_time += time.perf_counter() - start_compute_time
start_eval_time = time.perf_counter()
evaluator.process(batch, outputs)
total_eval_time += time.perf_counter() - start_eval_time
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
data_seconds_per_iter = total_data_time / iters_after_start
compute_seconds_per_iter = total_compute_time / iters_after_start
eval_seconds_per_iter = total_eval_time / iters_after_start
total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
if is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
(
f"Inference done {idx + 1}/{total}. "
f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
f"Total: {total_seconds_per_iter:.4f} s/iter. "
f"ETA={eta}"
),
n=5,
)
start_data_time = time.perf_counter()
            # evaluate and collect this dataset's results
            results[dataset_name] = evaluator.evaluate()
        model = model.train().cuda()
        return results
def setup(args):
"""
Create configs and perform basic setups.
"""
    cfg = LazyConfig.load(args.config_file)
cfg = LazyConfig.apply_overrides(cfg, args.opts)
default_setup(cfg, args)
setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino")
return cfg
def main(args=None):
cfg = setup(args)
print("Command cfg:", cfg)
if args.eval_only:
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
if args.original_load:
print("using original loading")
model = model.from_pretrained(cfg.MODEL.WEIGHTS)
res = Trainer.test(cfg, model)
if cfg.TEST.AUG.ENABLED:
res.update(Trainer.test_with_TTA(cfg, model))
return res
trainer = Trainer(cfg)
if len(args.lang_weight) > 0:
        # first load the language-encoder weights for the semantic classes, then
        # restore cfg.MODEL.WEIGHTS so resume_or_load() below loads the full model
        weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = args.lang_weight
        print("loading original language weight")
        # trainer.resume_or_load(resume=args.resume)
        trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = weight
        print("loading pretrained model weight")
trainer.resume_or_load(resume=args.resume)
    if args.original_load:  # load checkpoints whose keys use a different name prefix
        print("using original loading")
        try:
            trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(cfg.MODEL.WEIGHTS)
        except Exception:
            # fall back when the model is not wrapped in DistributedDataParallel
            trainer._trainer.model = trainer._trainer.model.from_pretrained(cfg.MODEL.WEIGHTS)
return trainer.train()
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument('--eval_only', action='store_true')
parser.add_argument('--original_load', action='store_true')
parser.add_argument('--lang_weight', type=str, default='')
parser.add_argument('--EVAL_FLAG', type=int, default=1)
args = parser.parse_args()
port = random.randint(1000, 20000)
args.dist_url = 'tcp://127.0.0.1:' + str(port)
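    # a random local TCP port for torch.distributed init avoids collisions when
    # several single-node jobs share a machine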
print("Command Line Args:", args)
print("pwd:", os.getcwd())
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)