# Ctrl-Crash / train_video_controlnet.py
from accelerate.utils import write_basic_config
write_basic_config()
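# write_basic_config writes a default single-machine `accelerate` config file if
# one is not already present, so the script can run without `accelerate config`
# having been executed first; it should not override an existing configuration.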
import warnings
import logging
import sys
import os
import math
import shutil
from pathlib import Path
import numpy as np
from einops import rearrange
import accelerate
from collections import defaultdict
import time
from PIL import Image, ImageDraw
from tqdm.auto import tqdm
import torch
torch.cuda.empty_cache()
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from packaging import version
from diffusers import EulerDiscreteScheduler
from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from src.utils import parse_args, encode_video_image, get_add_time_ids, get_samples
from src.datasets.dataset_utils import get_dataloader
from src.models import UNetSpatioTemporalConditionModel, ControlNetModel
from src.pipelines import StableVideoControlNullModelPipeline
if not is_wandb_available():
warnings.warn("Make sure to install wandb if you want to use it for logging during training.")
else:
import wandb
logger = get_logger(__name__, log_level="INFO")
def get_latest_checkpoint(checkpoint_dir):
dirs = os.listdir(checkpoint_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
return path
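# Example of the resolution above (assuming the "checkpoint-<global_step>" naming
# used by save_checkpoint below): a directory containing ["checkpoint-500",
# "checkpoint-1500", "logs"] yields "checkpoint-1500".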
def load_model_hook(models, input_dir):
for _ in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
if isinstance(model, UNetSpatioTemporalConditionModel):
load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
elif isinstance(model, ControlNetModel):
load_model = ControlNetModel.from_pretrained(input_dir, subfolder="control_net")
else:
raise Exception("Only UNetSpatioTemporalConditionModel and ControlNetModel are supported for loading.")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_models_accelerate(models, weights, output_dir, vae, feature_extractor, image_encoder, noise_scheduler):
    for model in models:
if isinstance(model, UNetSpatioTemporalConditionModel):
model.save_pretrained(os.path.join(output_dir, "unet"), safe_serialization=False)
elif isinstance(model, ControlNetModel):
model.save_pretrained(os.path.join(output_dir, "control_net"), safe_serialization=False)
else:
raise Exception("Only UNetSpatioTemporalConditionModel and ControlNetModel are supported for saving.")
# Also save other (frozen) components, just so they are found in the same checkpoint
# vae.save_pretrained(os.path.join(output_dir, "vae"), safe_serialization=False)
# feature_extractor.save_pretrained(os.path.join(output_dir, "feature_extractor"), safe_serialization=False)
# image_encoder.save_pretrained(os.path.join(output_dir, "image_encoder"), safe_serialization=False)
# noise_scheduler.save_pretrained(os.path.join(output_dir, "scheduler"), safe_serialization=False)
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
class TrainVideoControlnet:
def __init__(self, args):
self.args = args
def get_accelerator(self):
logging_dir = Path(self.args.output_dir, self.args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
mixed_precision=self.args.mixed_precision,
log_with=self.args.report_to,
project_config=accelerator_project_config,
)
self.accelerator = accelerator
def log_setup(self):
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(self.accelerator.state, main_process_only=False)
# If passed along, set the training seed now.
if self.args.seed is not None:
set_seed(self.args.seed)
# Handle the repository creation
if self.accelerator.is_main_process and self.args.output_dir is not None:
os.makedirs(self.args.output_dir, exist_ok=True)
def load_models_from_pretrained(self):
        # Load the trainable UNet plus the frozen SVD components (vae, image encoder, scheduler).
unet = UNetSpatioTemporalConditionModel.from_pretrained(
self.args.pretrained_model_name_or_path,
subfolder="unet",
variant=None,
low_cpu_mem_usage=True,
num_frames=self.args.clip_length
)
# Pretrained (frozen) models from stable-video-diffusion
model_path = "stabilityai/stable-video-diffusion-img2vid-xt"
variant = "fp16"
vae = AutoencoderKLTemporalDecoder.from_pretrained(
model_path,
subfolder="vae",
revision=self.args.revision,
variant=variant
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
model_path,
subfolder="image_encoder",
revision=self.args.revision,
variant=variant
)
feature_extractor = CLIPImageProcessor.from_pretrained(
model_path,
subfolder="feature_extractor",
revision=self.args.revision
)
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
model_path,
subfolder="scheduler"
)
null_model_unet = UNetSpatioTemporalConditionModel.from_pretrained(
model_path,
subfolder="unet",
variant=None,
low_cpu_mem_usage=True,
num_frames=self.args.clip_length
)
return unet, vae, image_encoder, feature_extractor, noise_scheduler, null_model_unet
def get_dataloaders(self):
train_dataset, train_loader = get_dataloader(self.args.data_root, self.args.dataset_name, if_train=True, clip_length=self.args.clip_length,
batch_size=self.args.train_batch_size, num_workers=self.args.dataloader_num_workers, shuffle=True,
image_height=self.args.train_H, image_width=self.args.train_W,
non_overlapping_clips=self.args.non_overlapping_clips,
bbox_masking_prob=self.args.bbox_masking_prob
)
val_dataset, val_loader = get_dataloader(self.args.data_root, self.args.dataset_name, if_train=False, clip_length=self.args.clip_length,
batch_size=self.args.num_demo_samples, num_workers=self.args.dataloader_num_workers, shuffle=True,
image_height=self.args.train_H, image_width=self.args.train_W,
non_overlapping_clips=True,
)
# demo_samples = get_samples(val_loader, self.args.num_demo_samples, show_progress=True)
return train_dataset, train_loader, val_dataset, val_loader
def get_sigmas(self, timesteps, n_dim=5, dtype=torch.float32):
sigmas = self.noise_scheduler.sigmas.to(device=self.accelerator.device, dtype=dtype)
schedule_timesteps = self.noise_scheduler.timesteps.to(self.accelerator.device)
timesteps = timesteps.to(self.accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
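    # get_sigmas above maps each sampled timestep back to its noise level sigma and
    # reshapes it to n_dim dimensions (default 5, i.e. (B, 1, 1, 1, 1)) so it
    # broadcasts over (batch, frames, channels, height, width) latents.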
def setup_training(self):
self.get_accelerator()
self.log_setup()
self.last_save_time = 0
if self.accelerator.is_main_process:
self.last_save_time = time.time()
# Setup devices
accelerator_device = self.accelerator.device
unet, vae, image_encoder, feature_extractor, noise_scheduler, null_model_unet = self.load_models_from_pretrained()
# freeze parameters of models to save more memory
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet.requires_grad_(False)
null_model_unet.requires_grad_(False)
# Load the model
assert self.args.train_H % 8 == 0 and self.args.train_W % 8 == 0
self.bbox_embedding_shape = (4, self.args.train_H // 8, self.args.train_W // 8)
ctrlnet = ControlNetModel.from_unet(unet, action_dim=5, bbox_embedding_shape=self.bbox_embedding_shape)
ctrlnet.requires_grad_(True)
        # For mixed precision training we cast all non-trainable weights (vae, image_encoder,
        # unet and the null-model unet) to half precision, as these weights are only used for
        # inference and keeping them in full precision is not required.
self.weight_dtype = self.train_weights_dtype = torch.float32
if self.accelerator.mixed_precision == "fp16":
self.weight_dtype = torch.float16
elif self.accelerator.mixed_precision == "bf16":
self.weight_dtype = torch.bfloat16
        # Move the frozen models to device and cast to weight_dtype
vae.to(accelerator_device, dtype=self.weight_dtype)
image_encoder.to(accelerator_device, dtype=self.weight_dtype)
unet.to(accelerator_device, dtype=self.weight_dtype)
null_model_unet.to(accelerator_device, dtype=self.weight_dtype)
        # Customized checkpoint saving via hooks requires `accelerate` >= 0.16.0
assert version.parse(accelerate.__version__) >= version.parse("0.16.0")
def save_model_hook(models, weights, output_dir):
if self.accelerator.is_main_process:
save_models_accelerate(models, weights, output_dir, vae, feature_extractor, image_encoder, noise_scheduler)
self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)
if self.args.enable_gradient_checkpointing:
unet.enable_gradient_checkpointing()
if self.args.scale_lr:
self.args.learning_rate = (self.args.learning_rate * self.args.gradient_accumulation_steps * self.args.per_gpu_batch_size * self.accelerator.num_processes)
optimizer = torch.optim.AdamW(
list(ctrlnet.parameters()), # Include all trainable parameters
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
weight_decay=self.args.adam_weight_decay,
eps=self.args.adam_epsilon,
)
train_dataset, train_loader, _, val_loader = self.get_dataloaders()
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps)
if self.args.max_train_steps is None:
self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
self.args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=self.args.lr_warmup_steps * self.accelerator.num_processes,
num_training_steps=self.args.max_train_steps * self.accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
unet, ctrlnet, optimizer, train_loader, lr_scheduler, null_model_unet = self.accelerator.prepare(
unet, ctrlnet, optimizer, train_loader, lr_scheduler, null_model_unet
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps)
if overrode_max_train_steps:
self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch)
        # Initialize the trackers we use and store our configuration.
        # The trackers are initialized automatically on the main process.
if self.accelerator.is_main_process:
if not self.args.disable_wandb:
self.accelerator.init_trackers(self.args.project_name, config=vars(self.args), init_kwargs={"wandb": {"dir": self.args.output_dir, "name": self.args.run_name, "entity": self.args.wandb_entity}})
else:
print("WANDB LOGGING DISABLED")
self.unet = unet
self.ctrlnet = ctrlnet
self.vae = vae
self.image_encoder = image_encoder
self.feature_extractor = feature_extractor
self.noise_scheduler = noise_scheduler
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.lr_scheduler = lr_scheduler
self.train_dataset = train_dataset
self.null_model_unet = null_model_unet
def print_train_info(self):
total_batch_size = self.args.train_batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(self.train_dataset)}")
logger.info(f" Num Epochs = {self.args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {self.args.max_train_steps}")
logger.info(f" Number of processes = {self.accelerator.num_processes}")
def load_checkpoint(self):
self.first_epoch = 0
self.global_step = 0
if not self.args.resume_from_checkpoint:
self.initial_global_step = 0
return
if self.args.resume_from_checkpoint != "latest":
path = os.path.basename(self.args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
path = get_latest_checkpoint(self.args.output_dir)
if path is None:
self.accelerator.print(f"Checkpoint '{self.args.resume_from_checkpoint}' does not exist. Starting a new training run.")
self.args.resume_from_checkpoint = None
self.initial_global_step = 0
else:
self.accelerator.print(f"Resuming from checkpoint {path}")
self.accelerator.load_state(os.path.join(self.args.output_dir, path))
self.initial_global_step = self.global_step = int(path.split("-")[1])
# self.first_epoch = global_step // num_update_steps_per_epoch # Not calculating first epoch right when using multiple processes. Let's default to using more epochs
def unwrap_model(self, model):
model = self.accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def log_step(self, step_loss, train_loss):
logs = {"step_loss": step_loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
self.progress_bar.set_postfix(**logs)
# Checks if the accelerator has performed an optimization step behind the scenes
if self.accelerator.sync_gradients:
self.progress_bar.update(1)
self.global_step += 1
log_plot = {"train_loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0],}
if self.args.add_bbox_frame_conditioning:
log_plot["|attn_rz_weight|"] = self.unet.get_attention_rz_weight()
self.accelerator.log(log_plot, step=self.global_step)
def save_checkpoint(self):
# Checks if the accelerator has performed an optimization step behind the scenes (only checkpoint after the gradient accumulation steps)
if not self.accelerator.sync_gradients:
return
args = self.args
# Save if checkpointing step reached or job is about to expire
save_checkpoint_time = args.checkpointing_time > 0 and (time.time() - self.last_save_time >= args.checkpointing_time)
# Save if number of steps for checkpointing reached
        save_checkpoint_steps = self.global_step % args.checkpointing_steps == 0
if self.accelerator.is_main_process and (save_checkpoint_time or save_checkpoint_steps):
if save_checkpoint_time:
print("Saving checkpoint due to time. Time elapsed:", time.time() - self.last_save_time)
self.last_save_time = time.time()
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{self.global_step}")
self.accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
def run_inference_with_pipeline(self, pipeline, demo_samples, log_dict):
args = self.args
for sample_i, sample in tqdm(enumerate(demo_samples), desc="Validation", total=args.num_demo_samples):
action_type = sample['action_type'].unsqueeze(0)
frames = pipeline(sample['image_init'] if not args.generate_bbox else sample['bbox_init'],
cond_images=sample['bbox_images'].unsqueeze(0) if not args.generate_bbox else sample['gt_clip'].unsqueeze(0),
action_type=action_type,
height=self.train_dataset.resize_height, width=self.train_dataset.resize_width,
decode_chunk_size=8, motion_bucket_id=127, fps=args.fps,
num_inference_steps=args.num_inference_steps,
num_frames=args.clip_length,
control_condition_scale=args.conditioning_scale,
min_guidance_scale=args.min_guidance_scale,
max_guidance_scale=args.max_guidance_scale,
noise_aug_strength=args.noise_aug_strength,
generator=self.generator, output_type='pt').frames[0]
#frames = F.interpolate(frames, (train_dataset.orig_H, train_dataset.orig_W)).detach().cpu().numpy()*255
frames = frames.detach().cpu().numpy()*255
frames = frames.astype(np.uint8)
log_dict["generated_videos"].append(wandb.Video(frames, fps=args.fps))
if sample.get('bbox_images_np') is not None:
# Add action type text to ground truth bounding box frames
bbox_frames = sample['bbox_images_np'].copy()
action_id = action_type.item()
action_name = {0: "Normal", 1: "Ego", 2: "Ego/Veh", 3: "Veh", 4: "Veh/Veh"}
action_text = f"Action: {action_name[action_id]} ({action_id})"
for i in range(bbox_frames.shape[0]):
# Convert numpy array to PIL Image
frame = Image.fromarray(bbox_frames[i].transpose(1, 2, 0))
draw = ImageDraw.Draw(frame)
# Add text in top right corner
text_position = (frame.width - 10, 10) # 10 pixels from top, 10 pixels from right
draw.text(text_position, action_text, fill=(255, 255, 255), anchor="ra")
# Add video name in top left corner
                    text_position = (10, 10)  # 10 pixels from top, 10 pixels from left
draw.text(text_position, sample['vid_name'], fill=(255, 255, 255), anchor="la")
# Convert back to numpy array
bbox_frames[i] = np.array(frame).transpose(2, 0, 1)
log_dict["gt_bbox_frames"].append(wandb.Video(bbox_frames, fps=args.fps))
log_dict["gt_videos"].append(wandb.Video(sample['gt_clip_np'], fps=args.fps))
# frame_bboxes = wandb_frames_with_bbox(frames, sample['objects_tensors'], (train_dataset.orig_W, train_dataset.orig_H))
# log_dict["frames_with_bboxes_{}".format(sample_i)] = frame_bboxes
return log_dict
def load_pipeline(self):
        # NOTE: pipeline used for inference at the validation step; swap in a
        # different pipeline here if needed.
# Compatibility with pretrained models from ctrlv
import src
import src.models.unet_spatio_temporal_condition as unet_module
sys.modules['ctrlv'] = src
sys.modules['ctrlv.models'] = src.models
sys.modules['ctrlv.models.unet_spatio_temporal_condition'] = unet_module
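        # Aliasing `src` under the legacy `ctrlv` module names lets checkpoints that
        # were pickled with references to `ctrlv.*` classes deserialize against the
        # current `src.*` implementations.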
pipeline = StableVideoControlNullModelPipeline.from_pretrained(self.args.pretrained_model_name_or_path,
unet=self.unwrap_model(self.unet),
image_encoder=self.unwrap_model(self.image_encoder),
vae=self.unwrap_model(self.vae),
controlnet=self.unwrap_model(self.ctrlnet),
null_model=self.unwrap_model(self.null_model_unet),
feature_extractor=self.feature_extractor,
revision=self.args.revision,
variant=self.args.variant,
torch_dtype=self.weight_dtype,
)
pipeline = pipeline.to(self.accelerator.device)
pipeline.set_progress_bar_config(disable=True)
return pipeline
def validation_step(self, save_pipeline=False):
logger.info("Running validation... ")
log_dict = defaultdict(list)
with torch.autocast(str(self.accelerator.device).replace(":0", ""), enabled=self.accelerator.mixed_precision == "fp16"):
pipeline = self.load_pipeline()
if self.demo_samples is None:
self.demo_samples = get_samples(self.val_loader, self.args.num_demo_samples, show_progress=True)
self.ctrlnet.eval()
log_dict = self.run_inference_with_pipeline(pipeline, self.demo_samples, log_dict)
for tracker in self.accelerator.trackers:
if tracker.name == "wandb":
tracker.log(log_dict)
if save_pipeline:
pipeline.save_pretrained(self.args.output_dir)
del pipeline, log_dict
# torch.cuda.empty_cache()
logger.info("Validation complete. ")
def train_step(self, batch):
# Aliases
args = self.args
accelerator = self.accelerator
accelerator_device = self.accelerator.device
ctrlnet, unet, vae, image_encoder, feature_extractor = self.ctrlnet, self.unet, self.vae, self.image_encoder, self.feature_extractor
optimizer, noise_scheduler, lr_scheduler = self.optimizer, self.noise_scheduler, self.lr_scheduler
weight_dtype = self.weight_dtype
train_weights_dtype = self.train_weights_dtype
train_loss = 0.0
self.ctrlnet.train()
with accelerator.accumulate(self.ctrlnet):
# Forward pass
batch_size, video_length = batch['clips'].shape[0], batch['clips'].shape[1]
initial_images = batch['clips'][:,0,:,:,:] if not self.args.generate_bbox else batch['bbox_images'][:,0,:,:,:] # only use the first frame
            # Ensure the frozen encoders and the input batch are on the accelerator device.
            # Note: nn.Module.to() is in-place, but Tensor.to() is not, so the moved
            # image batch must be reassigned.
            if vae.device != accelerator_device:
                vae.to(accelerator_device)
                image_encoder.to(accelerator_device)
            initial_images = initial_images.to(accelerator_device)
# Encode input image
encoder_hidden_states = encode_video_image(initial_images, feature_extractor, weight_dtype, image_encoder).unsqueeze(1)
encoder_hidden_states = encoder_hidden_states.to(dtype=train_weights_dtype).to(accelerator_device)
# Encode input image using VAE
conditional_latents = vae.encode(initial_images.to(accelerator_device).to(weight_dtype)).latent_dist.sample()
conditional_latents = conditional_latents.to(accelerator_device).to(train_weights_dtype)
# Encode bbox image using VAE
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
bbox_frames = rearrange(batch['bbox_images'] if not args.generate_bbox else batch['clips'], 'b f c h w -> (b f) c h w').to(vae.device).to(weight_dtype)
            # Get the action type from the batch (the accident-type flag used for conditioning)
if self.args.use_action_conditioning:
action_type = batch['action_type'] # Shape: [batch_size, 1]
else:
action_type = None
# Encode using VAE (now with standard 3 channels)
bbox_em = vae.encode(bbox_frames).latent_dist.sample()
bbox_em = rearrange(bbox_em, '(b f) c h w -> b f c h w', f=video_length).to(accelerator_device).to(train_weights_dtype)
# Encode clip frames using VAE
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
frames = rearrange(batch['clips'] if not args.generate_bbox else batch['bbox_images'], 'b f c h w -> (b f) c h w').to(vae.device).to(weight_dtype)
latents = vae.encode(frames).latent_dist.sample()
latents = rearrange(latents, '(b f) c h w -> b f c h w', f=video_length).to(accelerator_device).to(train_weights_dtype)
target_latents = latents = latents * vae.config.scaling_factor
del batch, frames
noise = torch.randn_like(latents)
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=noise_scheduler.timesteps.device).long()
timesteps = noise_scheduler.timesteps[indices].to(accelerator_device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Scale the noisy latents for the UNet
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# inp_noisy_latents = noise_scheduler.scale_model_input(noisy_latents, timesteps)
inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
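            # This is the EDM input preconditioning c_in(sigma) = 1 / sqrt(sigma^2 + 1),
            # equivalent to the commented-out scale_model_input call above.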
added_time_ids = get_add_time_ids(
fps=args.fps-1,
motion_bucket_id=127,
noise_aug_strength=args.noise_aug_strength,
dtype=weight_dtype,
batch_size=batch_size,
unet=unet
).to(accelerator_device)
# Conditioning dropout to support classifier-free guidance during inference. For more details
# check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
            # Adapted from https://github.com/huggingface/diffusers/blob/0d2d424fbef933e4b81bea20a660ee6fc8b75ab0/docs/source/en/training/instructpix2pix.md
if args.conditioning_dropout_prob is not None:
random_p = torch.rand(batch_size, device=accelerator_device, generator=self.generator)
# Sample masks for the edit prompts.
prompt_mask = random_p < 2 * args.conditioning_dropout_prob
prompt_mask = prompt_mask.reshape(batch_size, 1, 1)
# Final text conditioning (initial image CLIP embedding)
null_conditioning = torch.zeros_like(encoder_hidden_states)
encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
# Sample masks for the original images.
image_mask_dtype = conditional_latents.dtype
image_mask = 1 - (
(random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
* (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
)
image_mask = image_mask.reshape(batch_size, 1, 1, 1)
# Final image conditioning.
conditional_latents = image_mask * conditional_latents
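            # With p = conditioning_dropout_prob, the two masks above yield three dropout
            # regimes (as in InstructPix2Pix): random_p < p drops only the CLIP image
            # embedding, p <= random_p < 2p drops both conditionings, and
            # 2p <= random_p < 3p drops only the first-frame latents; all remaining
            # samples keep full conditioning.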
# Bbox conditioning masking
if args.contiguous_bbox_masking_prob is not None and args.contiguous_bbox_masking_prob > 0:
random_p = torch.rand(batch_size, device=accelerator_device, generator=self.generator)
if args.contiguous_bbox_masking_start_ratio is not None and args.contiguous_bbox_masking_start_ratio > 0:
# Among the masked samples, randomly select some to mask from the start of the video (and the rest from the end)
random_threshold = args.contiguous_bbox_masking_prob * args.contiguous_bbox_masking_start_ratio
sample_mask_start = (random_p < random_threshold).view(batch_size, 1, 1, 1, 1)
sample_mask_end = ((random_p >= random_threshold) & (random_p < args.contiguous_bbox_masking_prob)).view(batch_size, 1, 1, 1, 1)
min_bbox_mask_idx_start, max_bbox_mask_idx_start = 0, (video_length + 1) # TODO: Determine schedule (decrease min mask idx over time)
                    # Reshape to (B, 1, 1, 1, 1) so each sample's cutoff broadcasts against
                    # the (1, F, 1, 1, 1) frame indices instead of the width dimension.
                    bbox_mask_idx_start = torch.randint(min_bbox_mask_idx_start, max_bbox_mask_idx_start, (batch_size,), device=accelerator_device).view(batch_size, 1, 1, 1, 1)
                    mask_cond_start = bbox_mask_idx_start > torch.arange(args.clip_length, device=accelerator_device).view(1, args.clip_length, 1, 1, 1)
bbox_em = torch.where(mask_cond_start & sample_mask_start, self.unwrap_model(self.ctrlnet).bbox_null_embedding, bbox_em)
else:
sample_mask_end = (random_p < args.contiguous_bbox_masking_prob).view(batch_size, 1, 1, 1, 1)
min_bbox_mask_idx, max_bbox_mask_idx = 0, (video_length + 1) # TODO: Determine schedule (decrease min mask idx over time)
                    # Reshape to (B, 1, 1, 1, 1) for the same broadcasting reason as above.
                    bbox_mask_idx = torch.randint(min_bbox_mask_idx, max_bbox_mask_idx, (batch_size,), device=accelerator_device).view(batch_size, 1, 1, 1, 1)
                    mask_cond = bbox_mask_idx <= torch.arange(args.clip_length, device=accelerator_device).view(1, args.clip_length, 1, 1, 1)
bbox_em = torch.where(mask_cond & sample_mask_end, self.unwrap_model(self.ctrlnet).bbox_null_embedding, bbox_em)
            # Concatenate the first-frame conditioning latents with the noisy latents along the channel dim.
# conditional_latents = unet.encode_bbox_frame(conditional_latents, None)
conditional_latents = conditional_latents.unsqueeze(1).repeat(1, self.args.clip_length, 1, 1, 1)
concatenated_noisy_latents = torch.cat([inp_noisy_latents, conditional_latents], dim=2)
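            # The channel-wise concat doubles the latent channels (4 -> 8), matching the
            # SVD img2vid UNet, whose conv_in expects the noisy latents stacked with the
            # first-frame conditioning latents.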
added_time_ids = added_time_ids.to(dtype=train_weights_dtype)
down_block_additional_residuals, mid_block_additional_residuals = ctrlnet(
concatenated_noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
added_time_ids=added_time_ids,
control_cond=bbox_em,
action_type=action_type,
conditioning_scale=args.conditioning_scale,
return_dict=False,
)
if args.empty_cuda_cache:
torch.cuda.empty_cache()
model_pred = unet(sample=concatenated_noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
added_time_ids=added_time_ids,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residuals=mid_block_additional_residuals,).sample
# Denoise the latents
c_out = -sigmas / ((sigmas**2 + 1)**0.5)
c_skip = 1 / (sigmas**2 + 1)
denoised_latents = model_pred * c_out + c_skip * noisy_latents
weighting = (1 + sigmas ** 2) * (sigmas**-2.0)
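            # EDM preconditioning (Karras et al., 2022) with sigma_data = 1, as in SVD:
            # the denoised prediction is D = c_skip * x_noisy + c_out * F_theta, with
            # c_skip = 1 / (sigma^2 + 1) and c_out = -sigma / sqrt(sigma^2 + 1), and the
            # loss weight lambda(sigma) = (sigma^2 + 1) / sigma^2 equalizes the effective
            # loss across noise levels.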
# MSE loss
step_loss = torch.mean((weighting.float() * (denoised_latents.float() - target_latents.float()) ** 2).reshape(target_latents.shape[0], -1), dim=1,)
step_loss = step_loss.mean()
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(step_loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(step_loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
return step_loss, train_loss
def train(self):
"""
Main training loop
"""
self.print_train_info()
# Potentially load in the weights and states from a previous save
self.load_checkpoint()
args = self.args
accelerator = self.accelerator
self.progress_bar = tqdm(
range(0, args.max_train_steps),
initial=self.initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
        self.generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
self.demo_samples = None # Lazy load for samples
test_valid = False
for _ in range(self.first_epoch, args.num_train_epochs):
for _, batch in enumerate(self.train_loader):
                # Check whether to run validation
if accelerator.sync_gradients:
if test_valid or (self.global_step % args.validation_steps == 0 and self.global_step != 0) or (args.val_on_first_step and self.global_step == 0):
if accelerator.is_main_process:
self.validation_step()
accelerator.wait_for_everyone()
test_valid = False
# Training step
step_loss, train_loss = self.train_step(batch)
# Log info
self.log_step(step_loss, train_loss)
# Potentially save checkpoint
self.save_checkpoint()
# if args.empty_cuda_cache:
# torch.cuda.empty_cache()
# if global_step >= args.max_train_steps:
# break
accelerator.wait_for_everyone()
# Run a final round of inference
if accelerator.is_main_process:
logger.info("Running inference before terminating...")
self.validation_step()
logging.info("Finished training.")
accelerator.end_training()
def main():
args = parse_args()
try:
train_controlnet = TrainVideoControlnet(args)
train_controlnet.setup_training() # Load models, setup logging and define training config
train_controlnet.train() # Train!
except KeyboardInterrupt:
if hasattr(train_controlnet, "accelerator"):
train_controlnet.accelerator.end_training()
if is_wandb_available():
wandb.finish()
print("Keyboard interrupt: shutdown requested... Exiting.")
        sys.exit(0)
    except Exception:
        import traceback
        if is_wandb_available():
            wandb.finish()
        traceback.print_exc(file=sys.stdout)
        sys.exit(1)
if __name__ == '__main__':
main()