from accelerate.utils import write_basic_config write_basic_config() import warnings import logging import sys import os import math import shutil from pathlib import Path import numpy as np from einops import rearrange import accelerate from collections import defaultdict import time from PIL import Image, ImageDraw from tqdm.auto import tqdm import torch torch.cuda.empty_cache() import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from packaging import version from diffusers import EulerDiscreteScheduler from diffusers.models import AutoencoderKLTemporalDecoder from diffusers.optimization import get_scheduler from diffusers.utils import is_wandb_available from diffusers.utils.torch_utils import is_compiled_module from src.utils import parse_args, encode_video_image, get_add_time_ids, get_samples from src.datasets.dataset_utils import get_dataloader from src.models import UNetSpatioTemporalConditionModel, ControlNetModel from src.pipelines import StableVideoControlNullModelPipeline if not is_wandb_available(): warnings.warn("Make sure to install wandb if you want to use it for logging during training.") else: import wandb logger = get_logger(__name__, log_level="INFO") def get_latest_checkpoint(checkpoint_dir): dirs = os.listdir(checkpoint_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None return path def load_model_hook(models, input_dir): for _ in range(len(models)): # pop models so that they are not loaded again model = models.pop() if isinstance(model, UNetSpatioTemporalConditionModel): load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet") elif isinstance(model, ControlNetModel): load_model = ControlNetModel.from_pretrained(input_dir, subfolder="control_net") else: raise Exception("Only UNetSpatioTemporalConditionModel and ControlNetModel are supported for loading.") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_models_accelerate(models, weights, output_dir, vae, feature_extractor, image_encoder, noise_scheduler): for i, model in enumerate(models): if isinstance(model, UNetSpatioTemporalConditionModel): model.save_pretrained(os.path.join(output_dir, "unet"), safe_serialization=False) elif isinstance(model, ControlNetModel): model.save_pretrained(os.path.join(output_dir, "control_net"), safe_serialization=False) else: raise Exception("Only UNetSpatioTemporalConditionModel and ControlNetModel are supported for saving.") # Also save other (frozen) components, just so they are found in the same checkpoint # vae.save_pretrained(os.path.join(output_dir, "vae"), safe_serialization=False) # feature_extractor.save_pretrained(os.path.join(output_dir, "feature_extractor"), safe_serialization=False) # image_encoder.save_pretrained(os.path.join(output_dir, "image_encoder"), safe_serialization=False) # noise_scheduler.save_pretrained(os.path.join(output_dir, "scheduler"), safe_serialization=False) # make sure to pop weight so that corresponding model is not saved again weights.pop() class TrainVideoControlnet: def __init__(self, args): self.args = args def get_accelerator(self): logging_dir = Path(self.args.output_dir, self.args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=self.args.gradient_accumulation_steps, mixed_precision=self.args.mixed_precision, log_with=self.args.report_to, project_config=accelerator_project_config, ) self.accelerator = accelerator def log_setup(self): # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(self.accelerator.state, main_process_only=False) # If passed along, set the training seed now. if self.args.seed is not None: set_seed(self.args.seed) # Handle the repository creation if self.accelerator.is_main_process and self.args.output_dir is not None: os.makedirs(self.args.output_dir, exist_ok=True) def load_models_from_pretrained(self): # Load scheduler, tokenizer and models. unet = UNetSpatioTemporalConditionModel.from_pretrained( self.args.pretrained_model_name_or_path, subfolder="unet", variant=None, low_cpu_mem_usage=True, num_frames=self.args.clip_length ) # Pretrained (frozen) models from stable-video-diffusion model_path = "stabilityai/stable-video-diffusion-img2vid-xt" variant = "fp16" vae = AutoencoderKLTemporalDecoder.from_pretrained( model_path, subfolder="vae", revision=self.args.revision, variant=variant ) image_encoder = CLIPVisionModelWithProjection.from_pretrained( model_path, subfolder="image_encoder", revision=self.args.revision, variant=variant ) feature_extractor = CLIPImageProcessor.from_pretrained( model_path, subfolder="feature_extractor", revision=self.args.revision ) noise_scheduler = EulerDiscreteScheduler.from_pretrained( model_path, subfolder="scheduler" ) null_model_unet = UNetSpatioTemporalConditionModel.from_pretrained( model_path, subfolder="unet", variant=None, low_cpu_mem_usage=True, num_frames=self.args.clip_length ) return unet, vae, image_encoder, feature_extractor, noise_scheduler, null_model_unet def get_dataloaders(self): train_dataset, train_loader = get_dataloader(self.args.data_root, self.args.dataset_name, if_train=True, clip_length=self.args.clip_length, batch_size=self.args.train_batch_size, num_workers=self.args.dataloader_num_workers, shuffle=True, image_height=self.args.train_H, image_width=self.args.train_W, non_overlapping_clips=self.args.non_overlapping_clips, bbox_masking_prob=self.args.bbox_masking_prob ) val_dataset, val_loader = get_dataloader(self.args.data_root, self.args.dataset_name, if_train=False, clip_length=self.args.clip_length, batch_size=self.args.num_demo_samples, num_workers=self.args.dataloader_num_workers, shuffle=True, image_height=self.args.train_H, image_width=self.args.train_W, non_overlapping_clips=True, ) # demo_samples = get_samples(val_loader, self.args.num_demo_samples, show_progress=True) return train_dataset, train_loader, val_dataset, val_loader def get_sigmas(self, timesteps, n_dim=5, dtype=torch.float32): sigmas = self.noise_scheduler.sigmas.to(device=self.accelerator.device, dtype=dtype) schedule_timesteps = self.noise_scheduler.timesteps.to(self.accelerator.device) timesteps = timesteps.to(self.accelerator.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma def setup_training(self): self.get_accelerator() self.log_setup() self.last_save_time = 0 if self.accelerator.is_main_process: self.last_save_time = time.time() # Setup devices accelerator_device = self.accelerator.device unet, vae, image_encoder, feature_extractor, noise_scheduler, null_model_unet = self.load_models_from_pretrained() # freeze parameters of models to save more memory vae.requires_grad_(False) image_encoder.requires_grad_(False) unet.requires_grad_(False) null_model_unet.requires_grad_(False) # Load the model assert self.args.train_H % 8 == 0 and self.args.train_W % 8 == 0 self.bbox_embedding_shape = (4, self.args.train_H // 8, self.args.train_W // 8) ctrlnet = ControlNetModel.from_unet(unet, action_dim=5, bbox_embedding_shape=self.bbox_embedding_shape) ctrlnet.requires_grad_(True) # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. self.weight_dtype = self.train_weights_dtype = torch.float32 if self.accelerator.mixed_precision == "fp16": self.weight_dtype = torch.float16 elif self.accelerator.mixed_precision == "bf16": self.weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype vae.to(accelerator_device, dtype=self.weight_dtype) image_encoder.to(accelerator_device, dtype=self.weight_dtype) unet.to(accelerator_device, dtype=self.weight_dtype) null_model_unet.to(accelerator_device, dtype=self.weight_dtype) # `accelerate` 0.16.0 will have better support for customized saving assert version.parse(accelerate.__version__) >= version.parse("0.16.0") def save_model_hook(models, weights, output_dir): if self.accelerator.is_main_process: save_models_accelerate(models, weights, output_dir, vae, feature_extractor, image_encoder, noise_scheduler) self.accelerator.register_save_state_pre_hook(save_model_hook) self.accelerator.register_load_state_pre_hook(load_model_hook) if self.args.enable_gradient_checkpointing: unet.enable_gradient_checkpointing() if self.args.scale_lr: self.args.learning_rate = (self.args.learning_rate * self.args.gradient_accumulation_steps * self.args.per_gpu_batch_size * self.accelerator.num_processes) optimizer = torch.optim.AdamW( list(ctrlnet.parameters()), # Include all trainable parameters lr=self.args.learning_rate, betas=(self.args.adam_beta1, self.args.adam_beta2), weight_decay=self.args.adam_weight_decay, eps=self.args.adam_epsilon, ) train_dataset, train_loader, _, val_loader = self.get_dataloaders() overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps) if self.args.max_train_steps is None: self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( self.args.lr_scheduler, optimizer=optimizer, num_warmup_steps=self.args.lr_warmup_steps * self.accelerator.num_processes, num_training_steps=self.args.max_train_steps * self.accelerator.num_processes, ) # Prepare everything with our `accelerator`. unet, ctrlnet, optimizer, train_loader, lr_scheduler, null_model_unet = self.accelerator.prepare( unet, ctrlnet, optimizer, train_loader, lr_scheduler, null_model_unet ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps) if overrode_max_train_steps: self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if self.accelerator.is_main_process: if not self.args.disable_wandb: self.accelerator.init_trackers(self.args.project_name, config=vars(self.args), init_kwargs={"wandb": {"dir": self.args.output_dir, "name": self.args.run_name, "entity": self.args.wandb_entity}}) else: print("WANDB LOGGING DISABLED") self.unet = unet self.ctrlnet = ctrlnet self.vae = vae self.image_encoder = image_encoder self.feature_extractor = feature_extractor self.noise_scheduler = noise_scheduler self.optimizer = optimizer self.train_loader = train_loader self.val_loader = val_loader self.lr_scheduler = lr_scheduler self.train_dataset = train_dataset self.null_model_unet = null_model_unet def print_train_info(self): total_batch_size = self.args.train_batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(self.train_dataset)}") logger.info(f" Num Epochs = {self.args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {self.args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {self.args.max_train_steps}") logger.info(f" Number of processes = {self.accelerator.num_processes}") def load_checkpoint(self): self.first_epoch = 0 self.global_step = 0 if not self.args.resume_from_checkpoint: self.initial_global_step = 0 return if self.args.resume_from_checkpoint != "latest": path = os.path.basename(self.args.resume_from_checkpoint) else: # Get the most recent checkpoint path = get_latest_checkpoint(self.args.output_dir) if path is None: self.accelerator.print(f"Checkpoint '{self.args.resume_from_checkpoint}' does not exist. Starting a new training run.") self.args.resume_from_checkpoint = None self.initial_global_step = 0 else: self.accelerator.print(f"Resuming from checkpoint {path}") self.accelerator.load_state(os.path.join(self.args.output_dir, path)) self.initial_global_step = self.global_step = int(path.split("-")[1]) # self.first_epoch = global_step // num_update_steps_per_epoch # Not calculating first epoch right when using multiple processes. Let's default to using more epochs def unwrap_model(self, model): model = self.accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model def log_step(self, step_loss, train_loss): logs = {"step_loss": step_loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]} self.progress_bar.set_postfix(**logs) # Checks if the accelerator has performed an optimization step behind the scenes if self.accelerator.sync_gradients: self.progress_bar.update(1) self.global_step += 1 log_plot = {"train_loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0],} if self.args.add_bbox_frame_conditioning: log_plot["|attn_rz_weight|"] = self.unet.get_attention_rz_weight() self.accelerator.log(log_plot, step=self.global_step) train_loss = 0.0 def save_checkpoint(self): # Checks if the accelerator has performed an optimization step behind the scenes (only checkpoint after the gradient accumulation steps) if not self.accelerator.sync_gradients: return args = self.args # Save if checkpointing step reached or job is about to expire save_checkpoint_time = args.checkpointing_time > 0 and (time.time() - self.last_save_time >= args.checkpointing_time) # Save if number of steps for checkpointing reached save_checkpoint_steps = self.global_step % args.checkpointing_steps == 0 or save_checkpoint_time if self.accelerator.is_main_process and (save_checkpoint_time or save_checkpoint_steps): if save_checkpoint_time: print("Saving checkpoint due to time. Time elapsed:", time.time() - self.last_save_time) self.last_save_time = time.time() # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints") logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{self.global_step}") self.accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") def run_inference_with_pipeline(self, pipeline, demo_samples, log_dict): args = self.args for sample_i, sample in tqdm(enumerate(demo_samples), desc="Validation", total=args.num_demo_samples): action_type = sample['action_type'].unsqueeze(0) frames = pipeline(sample['image_init'] if not args.generate_bbox else sample['bbox_init'], cond_images=sample['bbox_images'].unsqueeze(0) if not args.generate_bbox else sample['gt_clip'].unsqueeze(0), action_type=action_type, height=self.train_dataset.resize_height, width=self.train_dataset.resize_width, decode_chunk_size=8, motion_bucket_id=127, fps=args.fps, num_inference_steps=args.num_inference_steps, num_frames=args.clip_length, control_condition_scale=args.conditioning_scale, min_guidance_scale=args.min_guidance_scale, max_guidance_scale=args.max_guidance_scale, noise_aug_strength=args.noise_aug_strength, generator=self.generator, output_type='pt').frames[0] #frames = F.interpolate(frames, (train_dataset.orig_H, train_dataset.orig_W)).detach().cpu().numpy()*255 frames = frames.detach().cpu().numpy()*255 frames = frames.astype(np.uint8) log_dict["generated_videos"].append(wandb.Video(frames, fps=args.fps)) if sample.get('bbox_images_np') is not None: # Add action type text to ground truth bounding box frames bbox_frames = sample['bbox_images_np'].copy() action_id = action_type.item() action_name = {0: "Normal", 1: "Ego", 2: "Ego/Veh", 3: "Veh", 4: "Veh/Veh"} action_text = f"Action: {action_name[action_id]} ({action_id})" for i in range(bbox_frames.shape[0]): # Convert numpy array to PIL Image frame = Image.fromarray(bbox_frames[i].transpose(1, 2, 0)) draw = ImageDraw.Draw(frame) # Add text in top right corner text_position = (frame.width - 10, 10) # 10 pixels from top, 10 pixels from right draw.text(text_position, action_text, fill=(255, 255, 255), anchor="ra") # Add video name in top left corner text_position = (10, 10) # 10 pixels from top, 10 pixels from right draw.text(text_position, sample['vid_name'], fill=(255, 255, 255), anchor="la") # Convert back to numpy array bbox_frames[i] = np.array(frame).transpose(2, 0, 1) log_dict["gt_bbox_frames"].append(wandb.Video(bbox_frames, fps=args.fps)) log_dict["gt_videos"].append(wandb.Video(sample['gt_clip_np'], fps=args.fps)) # frame_bboxes = wandb_frames_with_bbox(frames, sample['objects_tensors'], (train_dataset.orig_W, train_dataset.orig_H)) # log_dict["frames_with_bboxes_{}".format(sample_i)] = frame_bboxes return log_dict def load_pipeline(self): # NOTE: Pipeline used for inference at validation step, can change for different pipelines # Compatibility with pretrained models from ctrlv import src import src.models.unet_spatio_temporal_condition as unet_module sys.modules['ctrlv'] = src sys.modules['ctrlv.models'] = src.models sys.modules['ctrlv.models.unet_spatio_temporal_condition'] = unet_module pipeline = StableVideoControlNullModelPipeline.from_pretrained(self.args.pretrained_model_name_or_path, unet=self.unwrap_model(self.unet), image_encoder=self.unwrap_model(self.image_encoder), vae=self.unwrap_model(self.vae), controlnet=self.unwrap_model(self.ctrlnet), null_model=self.unwrap_model(self.null_model_unet), feature_extractor=self.feature_extractor, revision=self.args.revision, variant=self.args.variant, torch_dtype=self.weight_dtype, ) pipeline = pipeline.to(self.accelerator.device) pipeline.set_progress_bar_config(disable=True) return pipeline def validation_step(self, save_pipeline=False): logger.info("Running validation... ") log_dict = defaultdict(list) with torch.autocast(str(self.accelerator.device).replace(":0", ""), enabled=self.accelerator.mixed_precision == "fp16"): pipeline = self.load_pipeline() if self.demo_samples is None: self.demo_samples = get_samples(self.val_loader, self.args.num_demo_samples, show_progress=True) self.ctrlnet.eval() log_dict = self.run_inference_with_pipeline(pipeline, self.demo_samples, log_dict) for tracker in self.accelerator.trackers: if tracker.name == "wandb": tracker.log(log_dict) if save_pipeline: pipeline.save_pretrained(self.args.output_dir) del pipeline, log_dict # torch.cuda.empty_cache() logger.info("Validation complete. ") def train_step(self, batch): # Aliases args = self.args accelerator = self.accelerator accelerator_device = self.accelerator.device ctrlnet, unet, vae, image_encoder, feature_extractor = self.ctrlnet, self.unet, self.vae, self.image_encoder, self.feature_extractor optimizer, noise_scheduler, lr_scheduler = self.optimizer, self.noise_scheduler, self.lr_scheduler weight_dtype = self.weight_dtype train_weights_dtype = self.train_weights_dtype train_loss = 0.0 self.ctrlnet.train() with accelerator.accumulate(self.ctrlnet): # Forward pass batch_size, video_length = batch['clips'].shape[0], batch['clips'].shape[1] initial_images = batch['clips'][:,0,:,:,:] if not self.args.generate_bbox else batch['bbox_images'][:,0,:,:,:] # only use the first frame # check device if vae.device != accelerator_device: vae.to(accelerator_device) image_encoder.to(accelerator_device) initial_images.to(accelerator_device) # Encode input image encoder_hidden_states = encode_video_image(initial_images, feature_extractor, weight_dtype, image_encoder).unsqueeze(1) encoder_hidden_states = encoder_hidden_states.to(dtype=train_weights_dtype).to(accelerator_device) # Encode input image using VAE conditional_latents = vae.encode(initial_images.to(accelerator_device).to(weight_dtype)).latent_dist.sample() conditional_latents = conditional_latents.to(accelerator_device).to(train_weights_dtype) # Encode bbox image using VAE # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] bbox_frames = rearrange(batch['bbox_images'] if not args.generate_bbox else batch['clips'], 'b f c h w -> (b f) c h w').to(vae.device).to(weight_dtype) # Get the selected option from the batch (This is the accident type flag) if self.args.use_action_conditioning: action_type = batch['action_type'] # Shape: [batch_size, 1] else: action_type = None # Encode using VAE (now with standard 3 channels) bbox_em = vae.encode(bbox_frames).latent_dist.sample() bbox_em = rearrange(bbox_em, '(b f) c h w -> b f c h w', f=video_length).to(accelerator_device).to(train_weights_dtype) # Encode clip frames using VAE # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] frames = rearrange(batch['clips'] if not args.generate_bbox else batch['bbox_images'], 'b f c h w -> (b f) c h w').to(vae.device).to(weight_dtype) latents = vae.encode(frames).latent_dist.sample() latents = rearrange(latents, '(b f) c h w -> b f c h w', f=video_length).to(accelerator_device).to(train_weights_dtype) target_latents = latents = latents * vae.config.scaling_factor del batch, frames noise = torch.randn_like(latents) indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=noise_scheduler.timesteps.device).long() timesteps = noise_scheduler.timesteps[indices].to(accelerator_device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Scale the noisy latents for the UNet sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) # inp_noisy_latents = noise_scheduler.scale_model_input(noisy_latents, timesteps) inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5) added_time_ids = get_add_time_ids( fps=args.fps-1, motion_bucket_id=127, noise_aug_strength=args.noise_aug_strength, dtype=weight_dtype, batch_size=batch_size, unet=unet ).to(accelerator_device) # Conditioning dropout to support classifier-free guidance during inference. For more details # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. # Addapted from https://github.com/huggingface/diffusers/blob/0d2d424fbef933e4b81bea20a660ee6fc8b75ab0/docs/source/en/training/instructpix2pix.md if args.conditioning_dropout_prob is not None: random_p = torch.rand(batch_size, device=accelerator_device, generator=self.generator) # Sample masks for the edit prompts. prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(batch_size, 1, 1) # Final text conditioning (initial image CLIP embedding) null_conditioning = torch.zeros_like(encoder_hidden_states) encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. image_mask_dtype = conditional_latents.dtype image_mask = 1 - ( (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) ) image_mask = image_mask.reshape(batch_size, 1, 1, 1) # Final image conditioning. conditional_latents = image_mask * conditional_latents # Bbox conditioning masking if args.contiguous_bbox_masking_prob is not None and args.contiguous_bbox_masking_prob > 0: random_p = torch.rand(batch_size, device=accelerator_device, generator=self.generator) if args.contiguous_bbox_masking_start_ratio is not None and args.contiguous_bbox_masking_start_ratio > 0: # Among the masked samples, randomly select some to mask from the start of the video (and the rest from the end) random_threshold = args.contiguous_bbox_masking_prob * args.contiguous_bbox_masking_start_ratio sample_mask_start = (random_p < random_threshold).view(batch_size, 1, 1, 1, 1) sample_mask_end = ((random_p >= random_threshold) & (random_p < args.contiguous_bbox_masking_prob)).view(batch_size, 1, 1, 1, 1) min_bbox_mask_idx_start, max_bbox_mask_idx_start = 0, (video_length + 1) # TODO: Determine schedule (decrease min mask idx over time) bbox_mask_idx_start = torch.randint(min_bbox_mask_idx_start, max_bbox_mask_idx_start, (batch_size,), device=accelerator_device) mask_cond_start = bbox_mask_idx_start > torch.arange(args.clip_length, device=accelerator_device).view(1, args.clip_length, 1, 1, 1) bbox_em = torch.where(mask_cond_start & sample_mask_start, self.unwrap_model(self.ctrlnet).bbox_null_embedding, bbox_em) else: sample_mask_end = (random_p < args.contiguous_bbox_masking_prob).view(batch_size, 1, 1, 1, 1) min_bbox_mask_idx, max_bbox_mask_idx = 0, (video_length + 1) # TODO: Determine schedule (decrease min mask idx over time) bbox_mask_idx = torch.randint(min_bbox_mask_idx, max_bbox_mask_idx, (batch_size,), device=accelerator_device) mask_cond = bbox_mask_idx <= torch.arange(args.clip_length, device=accelerator_device).view(1, args.clip_length, 1, 1, 1) bbox_em = torch.where(mask_cond & sample_mask_end, self.unwrap_model(self.ctrlnet).bbox_null_embedding, bbox_em) # Concatenate the `original_image_embeds` with the `noisy_latents`. # conditional_latents = unet.encode_bbox_frame(conditional_latents, None) conditional_latents = conditional_latents.unsqueeze(1).repeat(1, self.args.clip_length, 1, 1, 1) concatenated_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2) added_time_ids = added_time_ids.to(dtype=train_weights_dtype) down_block_additional_residuals, mid_block_additional_residuals = ctrlnet( concatenated_noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids, control_cond=bbox_em, action_type=action_type, conditioning_scale=args.conditioning_scale, return_dict=False, ) if args.empty_cuda_cache: torch.cuda.empty_cache() model_pred = unet(sample=concatenated_noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residuals=mid_block_additional_residuals,).sample # Denoise the latents c_out = -sigmas / ((sigmas**2 + 1)**0.5) c_skip = 1 / (sigmas**2 + 1) denoised_latents = model_pred * c_out + c_skip * noisy_latents weighting = (1 + sigmas ** 2) * (sigmas**-2.0) # MSE loss step_loss = torch.mean((weighting.float() * (denoised_latents.float() - target_latents.float()) ** 2).reshape(target_latents.shape[0], -1), dim=1,) step_loss = step_loss.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(step_loss.repeat(args.train_batch_size)).mean() train_loss += avg_loss.item() / args.gradient_accumulation_steps # Backpropagate accelerator.backward(step_loss) optimizer.step() lr_scheduler.step() optimizer.zero_grad() return step_loss, train_loss def train(self): """ Main training loop """ self.print_train_info() # Potentially load in the weights and states from a previous save self.load_checkpoint() args = self.args accelerator = self.accelerator self.progress_bar = tqdm( range(0, args.max_train_steps), initial=self.initial_global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) self.generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None self.demo_samples = None # Lazy load for samples test_valid = False for _ in range(self.first_epoch, args.num_train_epochs): for _, batch in enumerate(self.train_loader): # Check for for validation if accelerator.sync_gradients: if test_valid or (self.global_step % args.validation_steps == 0 and self.global_step != 0) or (args.val_on_first_step and self.global_step == 0): if accelerator.is_main_process: self.validation_step() accelerator.wait_for_everyone() test_valid = False # Training step step_loss, train_loss = self.train_step(batch) # Log info self.log_step(step_loss, train_loss) # Potentially save checkpoint self.save_checkpoint() # if args.empty_cuda_cache: # torch.cuda.empty_cache() # if global_step >= args.max_train_steps: # break accelerator.wait_for_everyone() # Run a final round of inference if accelerator.is_main_process: logger.info("Running inference before terminating...") self.validation_step() logging.info("Finished training.") accelerator.end_training() def main(): args = parse_args() try: train_controlnet = TrainVideoControlnet(args) train_controlnet.setup_training() # Load models, setup logging and define training config train_controlnet.train() # Train! except KeyboardInterrupt: if hasattr(train_controlnet, "accelerator"): train_controlnet.accelerator.end_training() if is_wandb_available(): wandb.finish() print("Keyboard interrupt: shutdown requested... Exiting.") exit() except Exception: import sys, traceback if is_wandb_available(): wandb.finish() traceback.print_exc(file=sys.stdout) sys.exit(0) if __name__ == '__main__': main()