import json
import multiprocessing as mp
import os
import random
import re
import sys
import time
from contextlib import contextmanager
from glob import glob
from pathlib import Path
from typing import Any, Dict, Tuple, cast

import click
import numpy as np
from omegaconf import DictConfig, ListConfig, OmegaConf
from safetensors.torch import save_file
import torch
from torch import Tensor
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
import torch.nn.functional as F
from tqdm import tqdm

torch._dynamo.config.cache_size_limit = 32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.use_deterministic_algorithms(False)

import genmo.mochi_preview.dit.joint_model.lora as lora
from genmo.lib.progress import progress_bar
from genmo.lib.utils import Timer, save_video
from genmo.mochi_preview.pipelines import (
    DecoderModelFactory,
    DitModelFactory,
    ModelFactory,
    T5ModelFactory,
    cast_dit,
    compute_packed_indices,
    get_conditioning,
    linear_quadratic_schedule,  # used in eval'd Python code in lora.yaml
    load_to_cpu,
    move_to_device,
    sample_model,
    t5_tokenizer,
)
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
from genmo.mochi_preview.vae.models import decode_latents_tiled_spatial

sys.path.append("..")

from dataset import LatentEmbedDataset


class MochiTorchRunEvalPipeline:
    def __init__(
        self,
        *,
        device_id,
        dit,
        text_encoder_factory: ModelFactory,
        decoder_factory: ModelFactory,
    ):
        self.device = torch.device(f"cuda:{device_id}")
        self.tokenizer = t5_tokenizer()
        t = Timer()
        self.dit = dit
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(
                local_rank=0,
                world_size=1,
                device_id="cpu",
            )
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(local_rank=0, device_id="cpu", world_size=1)
        t.print_stats()  # type: ignore

    def __call__(self, prompt, save_path, **kwargs):
        with progress_bar(type="tqdm", enabled=True), torch.inference_mode():
            # Encode prompt with T5 XXL.
            with move_to_device(self.text_encoder, self.device, enabled=True):
                conditioning = get_conditioning(
                    self.tokenizer,
                    self.text_encoder,
                    self.device,
                    batch_inputs=False,
                    prompt=prompt,
                    negative_prompt="",
                )

            # Sample video latents from Mochi.
            with move_to_device(self.dit, self.device, enabled=True):
                latents = sample_model(self.device, self.dit, conditioning, **kwargs)

            # Decode video latents to frames.
            with move_to_device(self.decoder, self.device, enabled=True):
                frames = decode_latents_tiled_spatial(
                    self.decoder, latents, num_tiles_w=2, num_tiles_h=2, overlap=8
                )
            frames = frames.cpu().numpy()  # b t h w c
            assert isinstance(frames, np.ndarray)

            save_video(frames[0], save_path)


def map_to_device(x, device: torch.device):
    """Recursively move every tensor in a nested dict/list/tuple to `device`."""
    if isinstance(x, dict):
        return {k: map_to_device(v, device) for k, v in x.items()}
    elif isinstance(x, list):
        return [map_to_device(y, device) for y in x]
    elif isinstance(x, tuple):
        return tuple(map_to_device(y, device) for y in x)
    elif isinstance(x, torch.Tensor):
        return x.to(device, non_blocking=True)
    else:
        return x


EPOCH_IDX = 0


def infinite_dl(dl):
    """Cycle over the dataloader forever, counting completed passes in EPOCH_IDX."""
    global EPOCH_IDX
    while True:
        EPOCH_IDX += 1
        for batch in dl:
            yield batch


@contextmanager
def timer(description="Task", enabled=True):
    if enabled:
        start = time.perf_counter()
    try:
        yield
    finally:
        if enabled:
            elapsed = time.perf_counter() - start  # type: ignore
            print(f"{description} took {elapsed:.4f} seconds")


def get_cosine_annealing_lr_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
):
    # Linear warmup from 0 to the base LR over `warmup_steps`, then cosine decay
    # from the base LR down to 0 at `total_steps`.
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        else:
            decay_steps = max(1, total_steps - warmup_steps)  # avoid division by zero
            return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / decay_steps))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


@click.command()
@click.option("--config-path", type=click.Path(exists=True), required=True, help="Path to YAML config file")
def main(config_path):
    mp.set_start_method("spawn", force=True)
    cfg = cast(DictConfig, OmegaConf.load(config_path))

    device_id = 0
    device_str = f"cuda:{device_id}"
    device = torch.device(device_str)

    # Verify checkpoint path exists
    checkpoint_path = Path(cfg.init_checkpoint_path)
    assert checkpoint_path.exists(), f"Checkpoint file not found: {checkpoint_path}"

    # Create checkpoint directory if it doesn't exist
    checkpoint_dir = Path(cfg.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Get step number from checkpoint filename
    pattern = r"model_(\d+)\.(lora|checkpoint)\.(safetensors|pt)"
    match = re.search(pattern, str(checkpoint_path))
    if match:
        start_step_num = int(match.group(1))
        opt_path = str(checkpoint_path).replace("model_", "optimizer_")
    else:
        start_step_num = 0
        opt_path = ""

    print(
        f"model={checkpoint_path}, optimizer={opt_path}, start_step_num={start_step_num}"
    )

    wandb_run = None
    sample_prompts = cfg.sample.prompts

    train_vids = list(sorted(glob(f"{cfg.train_data_dir}/*.mp4")))
    train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
    print(f"Found {len(train_vids)} training videos in {cfg.train_data_dir}")
    assert len(train_vids) > 0, f"No training data found in {cfg.train_data_dir}"
    if cfg.single_video_mode:
        train_vids = train_vids[:1]
        sample_prompts = [Path(train_vids[0]).with_suffix(".txt").read_text()]
        print(f"Training on video: {train_vids[0]}")

    train_dataset = LatentEmbedDataset(
        train_vids,
        repeat=1_000 if cfg.single_video_mode else 1,
    )
    train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=None,
        num_workers=4,
        shuffle=True,
        pin_memory=True,
    )
    train_dl_iter = infinite_dl(train_dl)

    if cfg.get("wandb"):
        import wandb

        wandb_run = wandb.init(
            project=cfg.wandb.project,
            name=f"{cfg.wandb.name}-{int(time.time())}",
            config=OmegaConf.to_container(cfg),  # type: ignore
        )
        print(f"🚀 Weights & Biases run URL: {wandb_run.get_url()}")

    print("Loading model")
    patch_model_fns = []
    model_kwargs = {}
    is_lora = cfg.model.type == "lora"
    print(f"Training type: {'LoRA' if is_lora else 'Full'}")
    if is_lora:
        def mark_lora_params(m):
            lora.mark_only_lora_as_trainable(m, bias="none")
            return m

        patch_model_fns.append(mark_lora_params)
        model_kwargs = dict(**cfg.model.kwargs)
        # Replace ListConfig with list to allow serialization to JSON.
        for k, v in model_kwargs.items():
            if isinstance(v, ListConfig):
                model_kwargs[k] = list(v)

    if cfg.training.get("model_dtype"):
        assert cfg.training.model_dtype == "bf16", "Only bf16 is supported"
        patch_model_fns.append(lambda m: cast_dit(m, torch.bfloat16))

    model = (
        DitModelFactory(
            model_path=str(checkpoint_path), model_dtype="bf16", attention_mode=cfg.attention_mode
        ).get_model(
            local_rank=0,
            device_id=device_id,
            model_kwargs=model_kwargs,
            patch_model_fns=patch_model_fns,
            world_size=1,
            strict_load=not is_lora,
            fast_init=not is_lora,  # fast_init is not yet supported for LoRA
        )
        .train()  # calling train() makes sure LoRA weights are not merged
    )

    optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer)
    if os.path.exists(opt_path):
        print("Loading optimizer")
        optimizer.load_state_dict(load_to_cpu(opt_path))

    scheduler = get_cosine_annealing_lr_scheduler(
        optimizer, warmup_steps=cfg.training.warmup_steps, total_steps=cfg.training.num_steps
    )

    print("Loading eval pipeline ...")
    eval_pipeline = MochiTorchRunEvalPipeline(
        device_id=device_id,
        dit=model,
        text_encoder_factory=T5ModelFactory(),
        decoder_factory=DecoderModelFactory(model_path=cfg.sample.decoder_path),
    )

    def get_batch() -> Tuple[Dict[str, Any], Tensor, Tensor, Tensor]:
        nonlocal train_dl_iter
        batch = next(train_dl_iter)  # type: ignore
        latent, embed = cast(Tuple[Dict[str, Any], Dict[str, Any]], batch)
        assert len(embed["y_feat"]) == 1 and len(embed["y_mask"]) == 1, "Only batch size 1 is supported"

        # Sample a latent from the VAE posterior N(mean, exp(logvar)).
        ldist = LatentDistribution(latent["mean"], latent["logvar"])
        z = ldist.sample()
        assert torch.isfinite(z).all()
        assert z.shape[0] == 1, "Only batch size 1 is supported"

        # Gaussian noise and a noise level sigma ~ U[0, 1) per sample.
        eps = torch.randn_like(z)
        sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)

        # Caption dropout: zero the text conditioning with probability `caption_dropout`.
        if random.random() < cfg.training.caption_dropout:
            embed["y_mask"][0].zero_()
            embed["y_feat"][0].zero_()
        return embed, z, eps, sigma

    pbar = tqdm(
        range(start_step_num, cfg.training.num_steps),
        total=cfg.training.num_steps,
        initial=start_step_num,
    )
    for step in pbar:
        if cfg.sample.interval and step % cfg.sample.interval == 0 and step > 0:
            sample_dir = Path(cfg.sample.output_dir)
            sample_dir.mkdir(parents=True, exist_ok=True)

            model.eval()
            for eval_idx, prompt in enumerate(sample_prompts):
                save_path = sample_dir / f"{eval_idx}_{step}.mp4"
                if save_path.exists():
                    print(f"Skipping {save_path} as it already exists")
                    continue

                # Keys ending in "_python_code" hold Python expressions that are eval'd.
                sample_kwargs = {
                    k.removesuffix("_python_code"): (eval(v) if k.endswith("_python_code") else v)
                    for k, v in cfg.sample.kwargs.items()
                }
                eval_pipeline(
                    prompt=prompt,
                    save_path=str(save_path),
                    seed=cfg.sample.seed + eval_idx,
                    **sample_kwargs,
                )
                Path(sample_dir / f"{eval_idx}_{step}.txt").write_text(prompt)
            model.train()

        if cfg.training.save_interval and step > 0 and step % cfg.training.save_interval == 0:
            with timer("get_state_dict"):
                if is_lora:
                    model_sd = lora.lora_state_dict(model, bias="none")
                else:
                    # NOTE: Not saving optimizer state dict to save space.
                    model_sd, _optimizer_sd = get_state_dict(
                        model, [], options=StateDictOptions(cpu_offload=True, full_state_dict=True)
                    )

            checkpoint_filename = f"model_{step}.{'lora' if is_lora else 'checkpoint'}.pt"
            save_path = checkpoint_dir / checkpoint_filename
            if cfg.training.get("save_safetensors", True):
                save_path = save_path.with_suffix(".safetensors")
                save_file(
                    model_sd,
                    save_path,
                    # `safetensors` only supports string-to-string metadata,
                    # so we serialize the kwargs to a JSON string.
                    metadata=dict(kwargs=json.dumps(model_kwargs)),
                )
            else:
                torch.save(model_sd, save_path)

        with torch.no_grad(), timer("load_batch", enabled=False):
            batch = get_batch()
            embed, z, eps, sigma = map_to_device(batch, device)
            embed = cast(Dict[str, Any], embed)

            num_latent_toks = np.prod(z.shape[-3:])
            indices = compute_packed_indices(device, cast(Tensor, embed["y_mask"][0]), int(num_latent_toks))

            # Interpolate linearly between clean latents (sigma = 0) and noise (sigma = 1).
            # The regression target ut = z - eps is the velocity that moves z_sigma
            # toward the clean latent as sigma decreases.
            sigma_bcthw = sigma[:, None, None, None, None]  # [B, 1, 1, 1, 1]
            z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
            ut = z - eps

        with torch.autocast("cuda", dtype=torch.bfloat16):
            preds = model(
                x=z_sigma,
                sigma=sigma,
                packed_indices=indices,
                **embed,
                num_ff_checkpoint=cfg.training.num_ff_checkpoint,
                num_qkv_checkpoint=cfg.training.num_qkv_checkpoint,
            )
            assert preds.shape == z.shape
            loss = F.mse_loss(preds.float(), ut.float())
        loss.backward()

        log_kwargs = {
            "train/loss": loss.item(),
            "train/epoch": EPOCH_IDX,
            "train/lr": scheduler.get_last_lr()[0],
        }

        if cfg.training.get("grad_clip"):
            assert not is_lora, "Gradient clipping not supported for LoRA"
            gnorm_before_clip = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=cfg.training.grad_clip
            )
            log_kwargs["train/gnorm"] = gnorm_before_clip.item()
        pbar.set_postfix(**log_kwargs)

        if wandb_run:
            wandb_run.log(log_kwargs, step=step)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()


if __name__ == "__main__":
    main()