""" |
|
OpenSeed Training Script based on MaskDINO. |
|
""" |
|
try:
    # Silence Shapely deprecation warnings if shapely is installed.
    from shapely.errors import ShapelyDeprecationWarning
    import warnings
    warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
except Exception:
    pass
|
|
|
import sys
import copy
import itertools
import logging
import os
import time
import random
import weakref

from collections import OrderedDict
from typing import Any, Dict, List, Set
from fvcore.nn.precise_bn import get_bn_modules

import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg, CfgNode, LazyConfig, instantiate
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    hooks,
    launch,
    create_ddp_model,
    AMPTrainer,
    SimpleTrainer,
)
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
from detectron2.utils.comm import get_world_size, is_main_process

from utils.arguments import load_opt_command
from datasets import (
    build_train_dataloader,
    build_evaluator,
    build_eval_dataloader,
)
from openseed import build_model
from openseed.BaseModel import BaseModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class Trainer(DefaultTrainer):
    """
    Extension of detectron2's DefaultTrainer adapted to OpenSeed / MaskDINO.
    """

    def __init__(self, cfg):
        super(DefaultTrainer, self).__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        self.checkpointer = DetectionCheckpointer(
            model,
            cfg['OUTPUT_DIR'],
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg['SOLVER']['MAX_ITER']
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
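    # A minimal sketch of the dict-style config keys this constructor reads
    # directly (values are illustrative only, not defaults of this repo):
    #
    #   OUTPUT_DIR: ./output
    #   SOLVER:
    #     MAX_ITER: 90000
    #     AMP:
    #       ENABLED: true      # picks AMPTrainer instead of SimpleTrainer
    #
    # build_model, build_optimizer, build_train_loader and build_hooks read
    # further keys; see the notes next to those methods.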
|
|
|
|
|
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, LR scheduling,
        checkpointing, evaluation, and event writing.

        Returns:
            list[HookBase]:
        """
        cfg = copy.deepcopy(self.cfg)
        cfg.DATALOADER.NUM_WORKERS = 0

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            # Placeholder for a PreciseBN hook; None entries are dropped by
            # TrainerBase.register_hooks().
            None,
        ]

        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Run writers last, so that evaluation metrics are written as well.
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret
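    # A minimal sketch of how an extra hook could be appended before the final
    # PeriodicWriter, assuming detectron2's BestCheckpointer and a metric name
    # that your evaluator actually reports (both are assumptions, not part of
    # this repo):
    #
    #   if comm.is_main_process():
    #       ret.insert(-1, hooks.BestCheckpointer(
    #           cfg.TEST.EVAL_PERIOD, self.checkpointer, "sem_seg/mIoU", mode="max"))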
|
|
|
    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It calls :func:`openseed.build_model` and wraps the result in
        :class:`BaseModel`. Overwrite it if you'd like a different model.
        """
        model = BaseModel(cfg, build_model(cfg)).cuda()
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model
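    # The BaseModel wrapper, rather than the raw network, is returned on
    # purpose: later code in this file loads weights through its
    # from_pretrained() method, and Trainer.test reaches the underlying
    # network via its `.model` attribute.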
|
|
|
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return build_evaluator(cfg, dataset_name, output_folder=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        return build_train_dataloader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        # `dataset_name` is ignored: build_eval_dataloader returns one loader
        # per dataset listed in cfg['DATASETS']['TEST'].
        loader = build_eval_dataloader(cfg)
        return loader
|
|
|
    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.projects.deeplab.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)
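    # Note: the scheduler builder imported above is the DeepLab project's
    # variant, which extends detectron2's stock builder with support for
    # SOLVER.LR_SCHEDULER_NAME == "WarmupPolyLR" (a statement about detectron2
    # itself, not something configured in this file).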
|
|
|
    @classmethod
    def build_optimizer(cls, cfg, model):
        cfg_solver = cfg['SOLVER']
        weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
        weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
        weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)

        defaults = {}
        defaults["lr"] = cfg_solver['BASE_LR']
        defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        lr_multiplier = cfg_solver['LR_MULTIPLIER']

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters shared between modules.
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)

                # Scale the LR of any parameter whose qualified name contains
                # a key of SOLVER.LR_MULTIPLIER (e.g. {"backbone": 0.1}).
                for key, lr_mul in lr_multiplier.items():
                    if key in "{}.{}".format(module_name, module_param_name):
                        hyperparams["lr"] = hyperparams["lr"] * lr_mul
                        if is_main_process():
                            logger.info("Modify learning rate of {}: {}".format(
                                "{}.{}".format(module_name, module_param_name), lr_mul))

                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                if "bias" in module_param_name:
                    hyperparams["weight_decay"] = weight_decay_bias
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # Clip the gradient norm over the whole model inside optimizer.step(),
            # since detectron2's built-in clipping operates per parameter.
            clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
            enable = (
                cfg_solver['CLIP_GRADIENTS']['ENABLED']
                and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg_solver['OPTIMIZER']
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg_solver['BASE_LR']
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        return optimizer
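    # The optimizer builder above expects the following SOLVER keys. The shape
    # is taken from the lookups in the code; the concrete values below are only
    # illustrative, not defaults shipped with this repo:
    #
    #   SOLVER:
    #     OPTIMIZER: ADAMW          # or SGD (then MOMENTUM is also read)
    #     BASE_LR: 0.0001
    #     WEIGHT_DECAY: 0.05
    #     WEIGHT_DECAY_NORM: 0.0
    #     WEIGHT_DECAY_EMBED: 0.0
    #     WEIGHT_DECAY_BIAS: 0.0    # optional, defaults to 0.0
    #     LR_MULTIPLIER: {"backbone": 0.1}   # substring match on "<module>.<param>"
    #     CLIP_GRADIENTS:
    #       ENABLED: true
    #       CLIP_TYPE: full_model
    #       CLIP_VALUE: 0.01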
|
|
|
    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        Scale the total batch size, learning rate, and schedule lengths so that
        training with ``num_workers`` workers matches the reference config.

        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE == 0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = copy.deepcopy(cfg)

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )
        return cfg
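    # Worked example of the scaling above (numbers are illustrative): with
    # REFERENCE_WORLD_SIZE=8, IMS_PER_BATCH=16, BASE_LR=1e-4, MAX_ITER=90000,
    # launching on 16 GPUs gives scale=2, so IMS_PER_BATCH becomes 32, BASE_LR
    # becomes 2e-4, MAX_ITER becomes 45000, and WARMUP_ITERS, STEPS,
    # EVAL_PERIOD and CHECKPOINT_PERIOD are likewise divided by 2.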
|
|
|
    @classmethod
    def test(cls, cfg, model, evaluators=None):
        from utils.misc import hook_metadata, hook_switcher, hook_opt
        from openseed.utils import get_class_names
        from detectron2.utils.logger import log_every_n_seconds
        import datetime

        dataloaders = cls.build_test_loader(cfg, dataset_name=None)
        dataset_names = cfg['DATASETS']['TEST']
        model = model.eval().cuda()
        model_without_ddp = model
        if not isinstance(model, BaseModel):
            model_without_ddp = model.module

        results = OrderedDict()
        for dataloader, dataset_name in zip(dataloaders, dataset_names):
            evaluator = build_evaluator(cfg, dataset_name, cfg['OUTPUT_DIR'])
            evaluator.reset()
            with torch.no_grad():
                names = get_class_names(dataset_name, cfg['MODEL'].get('BACKGROUND', True))

                # Point the model at the current dataset: metadata, number of
                # classes, and the text embeddings of its class names.
                model_without_ddp.model.metadata = MetadataCatalog.get(dataset_name)
                eval_type = model_without_ddp.model.metadata.evaluator_type
                if 'background' in names:
                    model_without_ddp.model.sem_seg_head.num_classes = len(names) - 1
                else:
                    model_without_ddp.model.sem_seg_head.num_classes = len(names)
                model_without_ddp.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(names, is_eval=True)
                hook_switcher(model_without_ddp, dataset_name)

                task = 'seg'

                total = len(dataloader)
                num_warmup = min(5, total - 1)
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0
                start_data_time = time.perf_counter()

                for idx, batch in enumerate(dataloader):
                    total_data_time += time.perf_counter() - start_data_time
                    if idx == num_warmup:
                        # Reset the timers once the warm-up iterations are done.
                        start_time = time.perf_counter()
                        total_data_time = 0
                        total_compute_time = 0
                        total_eval_time = 0
                    start_compute_time = time.perf_counter()

                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        outputs = model(batch, inference_task=task)

                    total_compute_time += time.perf_counter() - start_compute_time
                    start_eval_time = time.perf_counter()

                    evaluator.process(batch, outputs)
                    total_eval_time += time.perf_counter() - start_eval_time

                    iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                    data_seconds_per_iter = total_data_time / iters_after_start
                    compute_seconds_per_iter = total_compute_time / iters_after_start
                    eval_seconds_per_iter = total_eval_time / iters_after_start
                    total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start

                    if is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
                        eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                        log_every_n_seconds(
                            logging.INFO,
                            (
                                f"Inference done {idx + 1}/{total}. "
                                f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                                f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                                f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                                f"Total: {total_seconds_per_iter:.4f} s/iter. "
                                f"ETA={eta}"
                            ),
                            n=5,
                        )
                    start_data_time = time.perf_counter()

            results[dataset_name] = evaluator.evaluate()

        model = model.train().cuda()
        return results
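    # Evaluation runs once per entry in DATASETS.TEST: for every dataset the
    # class-name list is re-encoded through the language encoder
    # (get_text_embeddings) and hook_switcher re-targets the model, so one
    # checkpoint can be scored on several vocabularies in a single call.
    # The returned OrderedDict maps dataset name -> evaluator results.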
|
|
|
|
|
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino")
    return cfg
|
|
|
|
|
def main(args=None):
    cfg = setup(args)
    print("Command cfg:", cfg)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        if args.original_load:
            print("using original loading")
            model = model.from_pretrained(cfg.MODEL.WEIGHTS)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            # Note: test_with_TTA is not defined on this Trainer, so
            # TEST.AUG.ENABLED is expected to stay False.
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    trainer = Trainer(cfg)
    if len(args.lang_weight) > 0:
        # Load the language-encoder weights first, then restore MODEL.WEIGHTS
        # so the regular checkpoint is used for the rest of the model.
        weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = args.lang_weight
        print("load original language weight")
        trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = weight
        print("load pretrained model weight")
    trainer.resume_or_load(resume=args.resume)
    if args.original_load:
        print("using original loading")
        try:
            trainer._trainer.model.module = trainer._trainer.model.module.from_pretrained(cfg.MODEL.WEIGHTS)
        except Exception:
            # The model is not wrapped in DistributedDataParallel (e.g. single GPU).
            trainer._trainer.model = trainer._trainer.model.from_pretrained(cfg.MODEL.WEIGHTS)
    return trainer.train()
|
|
|
|
|
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--original_load', action='store_true')
    parser.add_argument('--lang_weight', type=str, default='')
    parser.add_argument('--EVAL_FLAG', type=int, default=1)
    args = parser.parse_args()
    port = random.randint(1000, 20000)
    args.dist_url = 'tcp://127.0.0.1:' + str(port)
    print("Command Line Args:", args)
    print("pwd:", os.getcwd())
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
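    # Example invocation (script name and config path are placeholders, not
    # files guaranteed by this repo):
    #
    #   python train_net.py --num-gpus 8 --config-file configs/openseed/your_config.yaml
    #
    # Add --eval_only to run evaluation only; weights are then taken from
    # MODEL.WEIGHTS in the config.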
|
|