"""
train.py - GLaMM Model Training on Mixed Datasets

Trains the GLaMM model using Caption, Region, and Segmentation datasets with a weighted random
sampling approach. Mixing the three data types is crucial for developing a versatile model capable
of handling diverse applications effectively.
"""

import os
import sys
import time
import tqdm
import random
import torch
import argparse
import deepspeed
import numpy as np
import transformers
from functools import partial
from torch.utils.data import ConcatDataset
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter

from model.GLaMM import GLaMMForCausalLM
from model.llava import conversation as conversation_lib

from dataset.dataset import custom_collate_fn, HybridSegDataset, HybridRegDataset, HybridCapDataset
from tools.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, AverageMeter, ProgressMeter, dict_to_cuda,
                         Summary, intersectionAndUnionGPU)

from dataset.segm_datasets.RefCOCO_Segm_ds import ReferSegmDataset
from dataset.region_datasets.RefCOCO_VG_Region_ds import RefCocoGRegDataset, VisualGenomeRegDataset
from dataset.caption_datasets.COCO_Caption_ds import CocoCapDataset
from dataset.gcg_datasets.GranDf_gcg_ds import OpenPsgGCGDataset, Flickr30kGCGDataset, RefCOCOgGCGDataset


def parse_args(args):
    parser = argparse.ArgumentParser(description="GLaMM Model Training")

    # Model settings
    parser.add_argument("--version", default="MBZUAI/GLaMM-GranD-Pretrained")
    parser.add_argument("--vision_pretrained", default="./checkpoints/sam_vit_h_4b8939.pth", type=str)
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14-336", type=str)
    parser.add_argument("--conv_type", default="llava_v1", type=str, choices=["llava_v1", "llava_llama_2"])
    parser.add_argument("--tune_mm_mlp_adapter", action="store_true")
    parser.add_argument("--freeze_mm_mlp_adapter", action="store_true")
    parser.add_argument("--mm_use_im_start_end", action="store_true", default=True)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="Image size for grounding image encoder")
    parser.add_argument("--model_max_length", default=1536, type=int)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--with_region", action="store_true", default=True)
    parser.add_argument("--mm_vision_select_layer", default=-2, type=int)
    parser.add_argument("--pretrain_mm_mlp_adapter", default="", type=str)
    parser.add_argument("--precision", default='bf16', type=str)

    # Dataset settings
    parser.add_argument("--use_cap_data", action="store_true", help="Use caption data")
    parser.add_argument("--use_reg_data", action="store_true", help="Use region data")
    parser.add_argument("--use_segm_data", action="store_true", help="Use segmentation data")
    parser.add_argument("--weight_cap", default=0.15, type=float, help="Sampling weight for caption data")
    parser.add_argument("--weight_reg", default=0.40, type=float, help="Sampling weight for region data")
    parser.add_argument("--weight_segm", default=0.45, type=float, help="Sampling weight for segmentation data")
    parser.add_argument("--dataset_dir", default="./data", type=str)
    parser.add_argument("--seg_dataset", default="Semantic_Segm||Refer_Segm||RefCoco_GCG||PSG_GCG||Flickr_GCG||GranDf_GCG",
                        type=str, help="Choose from: Semantic_Segm, Refer_Segm, RefCoco_GCG, GranDf_GCG, PSG_GCG, Flickr_GCG, GrandRefer_Segm")
    parser.add_argument("--segm_sample_rates", default="5,4,3,3,3,1", type=str)
    parser.add_argument("--reg_dataset", default="RefCoco_Reg||RefCocoG_Reg||RefCocoP_Reg||VisGen_Reg",
                        type=str, help="Choose from: RefCoco_Reg, RefCocoG_Reg, RefCocoP_Reg, VisGen_Reg, Flickr_Reg, GrandRefer_Reg")
    parser.add_argument("--reg_sample_rates", default="1,1,1,1", type=str)
    parser.add_argument("--cap_dataset", default="CocoCap||LLaVaInstruct", type=str,
                        help="Choose from: CocoCap, LLaVaInstruct, GrandCaptionDataset")
    parser.add_argument("--cap_sample_rates", default="1,1", type=str)
    parser.add_argument("--semantic_segm_data", default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary", type=str)
    parser.add_argument("--refer_segm_data", default="refcoco||refcoco+||refcocog||refclef", type=str)
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)

    # Training settings
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--auto_resume", action="store_true")
    parser.add_argument("--weight", default="", type=str)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument("--batch_size", default=2, type=int, help="batch size per device per step")
    parser.add_argument("--grad_accumulation_steps", default=10, type=int)
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=2, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")

    # Evaluation settings
    parser.add_argument("--val_dataset", default="CocoCapVal|RefCOCOgRegVal|RefCOCOgSegmVal", type=str,
                        help="Choose from: CocoCapVal, RefCOCOgRegVal, VisGenomeRegVal, RefCOCOgSegmVal, PsgGCGVal, "
                             "RefCocoGCGVal, FlickrGCGVal")
    parser.add_argument("--mask_validation", action="store_true")
    parser.add_argument("--no_eval", action="store_true")
    parser.add_argument("--eval_only", action="store_true")

    # Experiment settings
    parser.add_argument("--log_base_dir", default="./output", type=str)
    parser.add_argument("--exp_name", default="GlamFinetuneOS", type=str)

    return parser.parse_args(args)


def initialize_environment(args):
    """ Set up logging and model directories. """
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        return SummaryWriter(args.log_dir)
    return None


def setup_tokenizer_and_special_tokens(args):
    """ Load tokenizer and add special tokens. """
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version, model_max_length=args.model_max_length, padding_side="right", use_fast=False
    )
    print('\033[92m' + "---- Initialized tokenizer from: {} ----".format(args.version) + '\033[0m')
    tokenizer.pad_token = tokenizer.unk_token

    if not args.pretrained:
        if args.use_mm_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        # Region tokens used to mark box / point prompts in the input
        reg_tokens = ['<bbox>', '<point>']
        # Segmentation token that triggers mask prediction
        segmentation_tokens = ['[SEG]']
        # Phrase tokens used for grounded conversation generation (GCG)
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        tokenizer.add_tokens(special_tokens, special_tokens=True)

    args.bbox_token_idx = tokenizer("<bbox>", add_special_tokens=False).input_ids[0]
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    args.bop_token_idx = tokenizer("<p>", add_special_tokens=False).input_ids[0]
    args.eop_token_idx = tokenizer("</p>", add_special_tokens=False).input_ids[0]

    return tokenizer


def initialize_model(args, tokenizer):
    """ Initialize the GLaMM model. """
    model_args = {k: getattr(args, k) for k in
                  ["train_mask_decoder", "out_dim", "ce_loss_weight", "dice_loss_weight", "bce_loss_weight",
                   "seg_token_idx", "vision_pretrained", "vision_tower", "use_mm_start_end", "mm_vision_select_layer",
                   "pretrain_mm_mlp_adapter", "tune_mm_mlp_adapter", "freeze_mm_mlp_adapter", "mm_use_im_start_end",
                   "with_region", "bbox_token_idx", "eop_token_idx", "bop_token_idx"]}
    model_args["num_level_reg_features"] = 4

    model = GLaMMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
    )
    print('\033[92m' + "---- Initialized model from: {} ----".format(args.version) + '\033[0m')

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    return model


def prepare_model_for_training(model, tokenizer, args):
    # Enable input gradients and activation checkpointing
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    # Initialize the global image encoder (vision tower)
    print(
        '\033[92m' + "---- Initialized Global Image Encoder (vision tower) from: {} ----".format(
            args.vision_tower
        ) + '\033[0m'
    )
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16, device=args.local_rank)

    # Initialize the grounding encoder when training from scratch; otherwise freeze it and
    # keep only the mask decoder and the text-to-mask projection layers trainable
    if not args.pretrained:
        model.get_model().initialize_glamm_model(model.get_model().config)
    else:
        for param in model.get_model().grounding_encoder.parameters():
            param.requires_grad = False
        if model.get_model().config.train_mask_decoder:
            model.get_model().grounding_encoder.mask_decoder.train()
            for param in model.get_model().grounding_encoder.mask_decoder.parameters():
                param.requires_grad = True

        # Projection from LLM hidden states to the mask decoder
        model.get_model().text_hidden_fcs.train()
        for param in model.get_model().text_hidden_fcs.parameters():
            param.requires_grad = True

    # Freeze the vision tower and the multimodal projector
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    # If LoRA is disabled, train the full language model and the projector instead
    lora_r = args.lora_r
    if lora_r == 0:
        for p in model.get_model().layers.parameters():
            p.requires_grad = True
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True

    # Set the conversation template
    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv_type]

    # Wrap the language model with LoRA adapters
    if lora_r > 0:
        lora_config = setup_lora_config(model, args)
        model = get_peft_model(model, lora_config)

    # Resize token embeddings to account for the added special tokens
    model.resize_token_embeddings(len(tokenizer))

    # Make selected modules trainable
    set_trainable_modules(model)


def setup_lora_config(model, args):
    """ Configure LoRA settings for the model. """

    def find_proj_layers(model, target_modules):
        """ Identify projection layers in the model for LoRA adaptation. """
        linear_cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if (isinstance(module, linear_cls) and all(
                    x not in name for x in ["grounding_encoder", "vision_tower", "mm_projector", "text_hidden_fcs"]
            ) and any(x in name for x in target_modules)):
                lora_module_names.add(name)
        return sorted(list(lora_module_names))

    lora_target_modules = args.lora_target_modules.split(",")
    lora_module_names = find_proj_layers(model, lora_target_modules)

    lora_config = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_alpha, target_modules=lora_module_names, lora_dropout=args.lora_dropout,
        bias="none", task_type="CAUSAL_LM"
    )
    return lora_config
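
# Note: with the default --lora_target_modules "q_proj,v_proj", find_proj_layers() matches only the
# attention projections of the language model (names such as "model.layers.0.self_attn.q_proj" are
# illustrative; the exact prefixes depend on the underlying LLaVA/LLaMA implementation), while the
# grounding encoder, vision tower, mm_projector and text_hidden_fcs are explicitly excluded above.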


def set_trainable_modules(model):
    """ Make specified modules in the model trainable. """
    trainable_modules = ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs", "region_encoder"]
    for name, param in model.named_parameters():
        if any(module in name for module in trainable_modules):
            print(f"Making trainable: {name}, Shape: {param.shape}")
            param.requires_grad = True

    def count_parameters(model):
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('\033[92m' + "---- Total parameters: ----{}".format(total_params) + '\033[0m')
        print('\033[92m' + "---- Trainable parameters: ----{}".format(trainable_params) + '\033[0m')

    count_parameters(model)


def initialize_datasets_and_loaders(args, tokenizer):
    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1

    # Common dataset arguments
    common_ds_args = {"dataset_dir": args.dataset_dir, "tokenizer": tokenizer,
                      "global_image_encoder": args.vision_tower,
                      "epoch_samples": args.batch_size * args.grad_accumulation_steps * args.steps_per_epoch * world_size,
                      "precision": args.precision, "image_size": args.image_size,
                      "num_classes_per_sample": args.num_classes_per_sample}

    # Training datasets
    cap_train_dataset = HybridCapDataset(
        **common_ds_args, dataset=args.cap_dataset, sample_rate=[float(x) for x in args.cap_sample_rates.split(",")],
        batch_size=args.batch_size, ) if args.use_cap_data else None
    reg_train_dataset = HybridRegDataset(
        **common_ds_args, dataset=args.reg_dataset, sample_rate=[float(x) for x in args.reg_sample_rates.split(",")],
        batch_size=args.batch_size, ) if args.use_reg_data else None
    seg_train_dataset = HybridSegDataset(
        **common_ds_args, dataset=args.seg_dataset, sample_rate=[float(x) for x in args.segm_sample_rates.split(",")],
        semantic_segm_data=args.semantic_segm_data, refer_segm_data=args.refer_segm_data,
        batch_size=args.batch_size, ) if args.use_segm_data else None

    # Validation datasets
    val_datasets = []
    if not args.no_eval:
        val_dataset_classes = {'CocoCapVal': CocoCapDataset,
                               'RefCOCOgRegVal': RefCocoGRegDataset,
                               'VisGenomeRegVal': VisualGenomeRegDataset,
                               'RefCOCOgSegmVal': ReferSegmDataset,
                               'PsgGCGVal': OpenPsgGCGDataset,
                               'RefCocoGCGVal': RefCOCOgGCGDataset,
                               'FlickrGCGVal': Flickr30kGCGDataset, }
        for val_dataset_name in args.val_dataset.split('|'):
            val_dataset_class = val_dataset_classes.get(val_dataset_name)
            if val_dataset_class:
                if val_dataset_class == ReferSegmDataset:
                    # Only RefCOCOg is used for referring-segmentation validation;
                    # modify refer_segm_data below to include other splits.
                    refer_segm_data = 'refcocog'
                    all_datasets = refer_segm_data.split("||")
                    for d in all_datasets:
                        val_dataset_class = val_dataset_class(
                            **common_ds_args, validation=True, refer_segm_data=d, split='val'
                        )
                        val_dataset_class._set_len(len(val_dataset_class.refer_segm_data[d]['images']))
                        val_datasets.append(val_dataset_class)
                else:
                    val_datasets.append(val_dataset_class(**common_ds_args, validation=True))

    return cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets


def setup_data_loaders(args, cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets, tokenizer):
    sampler_args = {"shuffle": False, "drop_last": False}
    train_loader_args = {"batch_size": args.batch_size, "shuffle": False, "num_workers": args.workers,
                         "pin_memory": False}
    val_loader_args = {"batch_size": args.val_batch_size, "shuffle": False, "num_workers": args.workers,
                       "pin_memory": False}
    collate_fn_args_train = partial(
        custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank,
        inference=False
    )
    inference_mode = args.mask_validation
    collate_fn_args_val = partial(
        custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank,
        inference=inference_mode
    )

    cap_train_loader = torch.utils.data.DataLoader(
        cap_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
            cap_train_dataset, **sampler_args
        ), collate_fn=collate_fn_args_train, **train_loader_args
    ) if cap_train_dataset is not None else None
    reg_train_loader = torch.utils.data.DataLoader(
        reg_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
            reg_train_dataset, **sampler_args
        ), collate_fn=collate_fn_args_train, **train_loader_args
    ) if reg_train_dataset is not None else None
    seg_train_loader = torch.utils.data.DataLoader(
        seg_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
            seg_train_dataset, **sampler_args
        ), collate_fn=collate_fn_args_train, **train_loader_args
    ) if seg_train_dataset is not None else None

    val_loader = None
    if val_datasets:
        combined_val_datasets = ConcatDataset(val_datasets)
        val_loader = torch.utils.data.DataLoader(
            combined_val_datasets, **val_loader_args, collate_fn=collate_fn_args_val,
            sampler=torch.utils.data.distributed.DistributedSampler(combined_val_datasets, **sampler_args), )

    return cap_train_loader, reg_train_loader, seg_train_loader, val_loader


def initialize_deepspeed(model, tokenizer, args):
    ds_config = {"train_micro_batch_size_per_gpu": args.batch_size,
                 "gradient_accumulation_steps": args.grad_accumulation_steps,
                 "optimizer": {"type": "AdamW", "params": {"lr": args.lr, "weight_decay": 0.0,
                                                           "betas": (args.beta1, args.beta2)}},
                 "scheduler": {"type": "WarmupDecayLR",
                               "params": {"total_num_steps": args.epochs * args.steps_per_epoch, "warmup_min_lr": 0,
                                          "warmup_max_lr": args.lr, "warmup_num_steps": 100, "warmup_type": "linear"}},
                 "fp16": {"enabled": args.precision == "fp16"}, "bf16": {"enabled": args.precision == "bf16"},
                 "gradient_clipping": 1.0,
                 "zero_optimization": {"stage": 2, "contiguous_gradients": True, "overlap_comm": True,
                                       "reduce_scatter": True, "reduce_bucket_size": 5e8,
                                       "allgather_bucket_size": 5e8}, }

    model_engine, optimizer, _, scheduler = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), collate_fn=partial(
            custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank
        ), config=ds_config
    )

    return model_engine, optimizer, scheduler
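
# Note (illustrative arithmetic, not from the original source): with the ZeRO stage-2 config above,
# the effective global batch size per optimizer step is batch_size * grad_accumulation_steps * num_gpus,
# e.g. 2 * 10 * 8 = 160 on an 8-GPU node with the defaults; scale --batch_size or
# --grad_accumulation_steps accordingly when the device count changes.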


def resume_training_from_checkpoint(model_engine, args):
    if args.auto_resume and not args.resume:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        print(f"Resume training from {args.resume}, start from epoch {args.start_epoch}")


def main(args):
    tokenizer = setup_tokenizer_and_special_tokens(args)
    model = initialize_model(args, tokenizer)
    prepare_model_for_training(model, tokenizer, args)

    model_engine, optimizer, scheduler = initialize_deepspeed(model, tokenizer, args)
    resume_training_from_checkpoint(model_engine, args)

    cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets = (
        initialize_datasets_and_loaders(args, tokenizer))
    cap_train_loader, reg_train_loader, seg_train_loader, val_loader = (
        setup_data_loaders(args, cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets, tokenizer))

    # Determine the active dataset types and their sampling weights
    active_dataloaders = []
    weights = []

    if args.use_cap_data:
        active_dataloaders.append(('cap', cap_train_loader))
        weights.append(args.weight_cap)
    if args.use_reg_data:
        active_dataloaders.append(('reg', reg_train_loader))
        weights.append(args.weight_reg)
    if args.use_segm_data:
        active_dataloaders.append(('seg', seg_train_loader))
        weights.append(args.weight_segm)

    assert active_dataloaders, "Error: At least one dataset (segm, reg, or cap) must be active."

    dataset_iters = {'cap': iter(cap_train_loader) if args.use_cap_data else None,
                     'reg': iter(reg_train_loader) if args.use_reg_data else None,
                     'seg': iter(seg_train_loader) if args.use_segm_data else None, }

    writer = initialize_environment(args)

    if args.eval_only:
        cur_val_loss = validate_model_performance(val_loader, model_engine, 0, writer, args)
        exit()

    # Fix one seed per epoch so the weighted dataset choices are identical across ranks
    epoch_seeds = [random.randint(0, 100000) for _ in range(args.epochs)]
    dataset_choices = [idx for idx, _ in enumerate(active_dataloaders)]

    best_giou, best_ciou, best_val_loss = 0.0, 0.0, np.inf
    for epoch in range(args.start_epoch, args.epochs):
        random.seed(epoch_seeds[epoch])
        # Sample, for every step of this epoch, which dataset type (cap / reg / seg) to draw from
        step_choices = random.choices(dataset_choices, weights=weights, k=args.steps_per_epoch)

        dataset_iters = train(
            active_dataloaders, model_engine, epoch, scheduler, writer, dataset_iters, args, step_choices
        )

        if args.mask_validation:
            giou, ciou = validate_model_performance(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_giou
            best_giou = max(giou, best_giou)
            best_ciou = ciou if is_best else best_ciou
            if args.local_rank == 0:
                print(f"Epoch: {epoch}, giou: {giou}, ciou: {ciou}, best_giou: {best_giou}, best_ciou: {best_ciou}")
            save_checkpoint(model_engine, args, epoch, 'giou-ciou', f"{giou:.4f}-{ciou:.4f}", is_best)
        else:
            cur_val_loss = validate_model_performance(val_loader, model_engine, epoch, writer, args)
            is_best = cur_val_loss < best_val_loss
            best_val_loss = min(cur_val_loss, best_val_loss)
            if args.local_rank == 0:
                print(f"Epoch: {epoch}, Current Validation Loss: {cur_val_loss:.4f}, "
                      f"Best Validation Loss: {best_val_loss:.4f}")
            save_checkpoint(model_engine, args, epoch, 'loss', f"{cur_val_loss:.4f}", is_best)


def save_checkpoint(model_engine, args, epoch, metric_name, metric_value, is_best):
    """ Saves the model checkpoint. """
    # Keep one directory for the best checkpoint and one for the latest epoch
    save_dir_name = "ckpt_model_best" if is_best else "ckpt_model_last_epoch"
    save_dir = os.path.join(args.log_dir, save_dir_name)

    # Rank 0 writes a small marker file recording the epoch and metric value
    if args.local_rank == 0:
        os.makedirs(save_dir, exist_ok=True)
        ckpt_filename = f"epoch_{epoch}_val_{metric_name}_{metric_value}.pth"
        torch.save({"epoch": epoch, f"val_{metric_name}": metric_value}, os.path.join(save_dir, ckpt_filename))
    # All ranks participate in the DeepSpeed checkpoint save
    torch.distributed.barrier()
    model_engine.save_checkpoint(save_dir)
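
# Illustrative resulting layout (file names are examples, not taken from the original source):
#   <log_dir>/ckpt_model_best/epoch_3_val_giou-ciou_0.6512-0.7013.pth   <- rank-0 metric marker
#   <log_dir>/ckpt_model_best/global_step1500/*                         <- DeepSpeed engine state
#   <log_dir>/ckpt_model_best/latest                                    <- read by resume_training_from_checkpoint()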


def train(active_datasets, model, epoch, scheduler, writer, dataset_iters, args, step_choices):
    """Main training loop."""

    def get_next_input(iterator, data_loader):
        """Retrieve the next batch from the iterator, re-initializing it when exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            new_iterator = iter(data_loader)
            return next(new_iterator), new_iterator

    def log_progress():
        """Log training progress."""
        if global_step % args.print_freq == 0:
            if args.distributed:
                for tracker in trackers.values():
                    tracker.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                for key, tracker in trackers.items():
                    writer.add_scalar(f"train/{key}", tracker.avg, global_step)
                writer.add_scalar("metrics/total_secs_per_batch", batch_time.avg, global_step)
                writer.add_scalar("metrics/data_secs_per_batch", data_time.avg, global_step)

            for tracker in trackers.values():
                tracker.reset()

    batch_time = AverageMeter("Time", ":.4f")
    data_time = AverageMeter("Data", ":.4f")
    trackers = {"loss": AverageMeter("Loss", ":.4f"),
                "ce_loss": AverageMeter("CeLoss", ":.4f"),
                "mask_bce_loss": AverageMeter("MaskBCELoss", ":.4f"),
                "mask_dice_loss": AverageMeter("MaskDICELoss", ":.4f"),
                "mask_loss": AverageMeter("MaskLoss", ":.4f")}
    progress = ProgressMeter(args.steps_per_epoch, list(trackers.values()), prefix=f"Epoch: [{epoch}]")

    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for _ in range(args.grad_accumulation_steps):
            # Select the dataset type chosen for this step and fetch its next batch
            dataset_type, data_loader = active_datasets[step_choices[global_step]]
            data_batch, new_iter = get_next_input(dataset_iters[dataset_type], data_loader)
            dataset_iters[dataset_type] = new_iter

            data_time.update(time.time() - end)

            # Move the batch to the GPU and cast image tensors to bf16
            data_batch = dict_to_cuda(data_batch)
            for key in ["global_enc_images", "grounding_enc_images"]:
                if data_batch[key] is not None:
                    data_batch[key] = data_batch[key].bfloat16()

            output_dict = model(**data_batch)

            # Update the loss trackers
            for key, tracker in trackers.items():
                if key in output_dict:
                    tracker.update(output_dict[key].item(), data_batch["global_enc_images"].size(0))

            model.backward(output_dict["loss"])
            model.step()

        batch_time.update(time.time() - end)
        end = time.time()
        log_progress()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], global_step)

    return dataset_iters
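
# Note: model.step() is called after every micro-batch; with the DeepSpeed config above, the engine
# applies the optimizer (and advances the WarmupDecayLR scheduler) only at gradient-accumulation
# boundaries, so each outer iteration of the loop corresponds to one optimizer step.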


def validate_model_performance(validation_loader, training_model, current_epoch, tensorboard_writer, args):
    if args.mask_validation:
        # Segmentation (mask) validation: report gIoU and cIoU
        trackers = {"intersection": AverageMeter("Intersec", ":.4f", Summary.SUM),
                    "union": AverageMeter("Union", ":.4f", Summary.SUM),
                    "gIoU": AverageMeter("gIoU", ":.4f", Summary.SUM)}

        training_model.eval()
        for data_batch in tqdm.tqdm(validation_loader):
            # Move the batch to the GPU and cast image tensors to bf16
            data_batch = dict_to_cuda(data_batch)
            for key in ["global_enc_images", "grounding_enc_images"]:
                data_batch[key] = data_batch[key].bfloat16()
            torch.cuda.empty_cache()

            with torch.no_grad():
                results = training_model(**data_batch)

            predictions = results["pred_masks"]
            gt_masks = results["gt_masks"][0].int()
            # Binarize the predicted masks (positive logits -> foreground)
            predicted_masks = (predictions[0] > 0).int()
            assert len(predictions) == 1

            intersection, union, accuracy_iou = 0.0, 0.0, 0.0
            for target, prediction in zip(gt_masks, predicted_masks):
                intersect, union_, _ = intersectionAndUnionGPU(
                    prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255
                )
                intersection += intersect
                union += union_
                accuracy_iou += intersect / (union_ + 1e-5)
                # Count empty (no-object) targets as correct predictions
                accuracy_iou[union_ == 0] += 1.0

            intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
            accuracy_iou = accuracy_iou.cpu().numpy() / gt_masks.shape[0]
            trackers["intersection"].update(intersection)
            trackers["union"].update(union)
            trackers["gIoU"].update(accuracy_iou, n=gt_masks.shape[0])

        for meter in trackers.values():
            meter.all_reduce()

        iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10)
        class_iou = iou_per_class[1]
        global_iou = trackers["gIoU"].avg[1]

        if args.local_rank == 0:
            tensorboard_writer.add_scalar("val/giou", global_iou, current_epoch)
            tensorboard_writer.add_scalar("val/ciou", class_iou, current_epoch)
            print("giou: {:.4f}, ciou: {:.4f}".format(global_iou, class_iou))

        return global_iou, class_iou
    else:
        # Otherwise, report the validation loss
        trackers = {"loss": AverageMeter("Loss", ":.4f"), "ce_loss": AverageMeter("CeLoss", ":.4f"),
                    "mask_bce_loss": AverageMeter("MaskBCELoss", ":.4f"),
                    "mask_dice_loss": AverageMeter("MaskDICELoss", ":.4f"),
                    "mask_loss": AverageMeter("MaskLoss", ":.4f")}

        # The model is kept in train mode here; gradients are still disabled by torch.no_grad() below
        training_model.train()

        for data_batch in tqdm.tqdm(validation_loader):
            # Move the batch to the GPU and cast image tensors to bf16
            data_batch = dict_to_cuda(data_batch)
            for key in ["global_enc_images", "grounding_enc_images"]:
                if data_batch[key] is not None:
                    data_batch[key] = data_batch[key].bfloat16()
            torch.cuda.empty_cache()

            with torch.no_grad():
                predictions = training_model(**data_batch)

            for key, tracker in trackers.items():
                tracker.update(predictions[key].item(), data_batch["global_enc_images"].size(0))

        for tracker in trackers.values():
            tracker.all_reduce()

        avg_val_loss = trackers["ce_loss"].avg

        if args.local_rank == 0:
            tensorboard_writer.add_scalar("val/loss", avg_val_loss, current_epoch)

        return avg_val_loss


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    main(args)