import argparse
import math
import os
import sys

current_path = os.path.abspath(__file__)
father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")
sys.path.append(os.path.join(father_path, 'Next3d'))

from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf
import torch
import logging
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import inspect
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import dnnlib
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from vae.triplane_vae import AutoencoderKL, AutoencoderKLRollOut
from vae.data.dataset_online_vae import TriplaneDataset
from einops import rearrange
from vae.utils.common_utils import instantiate_from_config
from Next3d.training_avatar_texture.triplane_generation import TriPlaneGenerator
import Next3d.legacy as legacy
from torch_utils import misc
import datetime

logger = get_logger(__name__, log_level="INFO")


def collate_fn(data):
    model_names = [example["data_model_name"] for example in data]
    zs = torch.cat([example["data_z"] for example in data], dim=0)
    verts = torch.cat([example["data_vert"] for example in data], dim=0)
    return {
        'model_names': model_names,
        'zs': zs,
        'verts': verts
    }


def rollout_fn(triplane):
    # (B, C, 3, H, W) -> (B, C, H, 3W): lay the three planes side by side along the width axis.
    triplane = rearrange(triplane, "b c f h w -> b f c h w")
    b, f, c, h, w = triplane.shape
    triplane = triplane.permute(0, 2, 3, 1, 4).reshape(-1, c, h, f * w)
    return triplane


def unrollout_fn(triplane):
    # Fold a rolled-out (B, C, H, 3W) tensor back into a (B, C, 3, H, W) stack of three planes.
    res = triplane.shape[-2]
    ch = triplane.shape[1]
    triplane = triplane.reshape(-1, ch // 3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, 3, ch, res, res)
    triplane = rearrange(triplane, "b f c h w -> b c f h w")
    return triplane


def triplane_generate(G_model, z, conditioning_params, std, mean, truncation_psi=0.7, truncation_cutoff=14):
    # Sample a triplane from a frozen generator and normalize it with the dataset statistics.
    w = G_model.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
    triplane = G_model.synthesis(w, noise_mode='const')
    triplane = (triplane - mean) / std
    return triplane


def gan_model(gan_models, device, gan_model_base_dir):
    # Load every generator listed in the config and re-wrap it as a frozen TriPlaneGenerator.
    gan_model_dict = gan_models
    gan_model_load = {}
    for model_name in gan_model_dict.keys():
        model_pkl = os.path.join(gan_model_base_dir, model_name + '.pkl')
        with dnnlib.util.open_url(model_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
        misc.copy_params_and_buffers(G, G_new, require_all=True)
        G_new.neural_rendering_resolution = G.neural_rendering_resolution
        G_new.rendering_kwargs = G.rendering_kwargs
        gan_model_load[model_name] = G_new
    return gan_model_load

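
# Shape sketch (illustrative only; assumes square planes, H == W):
#   x: (B, C, 3, H, W)   three feature planes per sample
#   rollout_fn(x)   -> (B, C, H, 3W)   planes laid out side by side along width
#   unrollout_fn(y) -> (B, C, 3, H, W) for a rolled-out tensor y of shape (B, C, H, 3W)
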
def main(
        vae_config: str,
        gan_model_config: str,
        output_dir: str,
        std_dir: str,
        mean_dir: str,
        conditioning_params_dir: str,
        gan_model_base_dir: str,
        train_data: Dict,
        train_batch_size: int = 2,
        max_train_steps: int = 500,
        learning_rate: float = 3e-5,
        scale_lr: bool = False,
        lr_scheduler: str = "constant",
        lr_warmup_steps: int = 0,
        adam_beta1: float = 0.5,
        adam_beta2: float = 0.9,
        adam_weight_decay: float = 1e-2,
        adam_epsilon: float = 1e-08,
        max_grad_norm: float = 1.0,
        gradient_accumulation_steps: int = 1,
        gradient_checkpointing: bool = True,
        checkpointing_steps: int = 500,
        pretrained_model_path_zero123: str = None,
        resume_from_checkpoint: Optional[str] = None,
        mixed_precision: Optional[str] = "fp16",
        use_8bit_adam: bool = False,
        rollout: bool = False,
        enable_xformers_memory_efficient_attention: bool = True,
        seed: Optional[int] = None,
):
    # Capture the full argument set so it can be saved alongside the run.
    *_, config = inspect.getargvalues(inspect.currentframe())

    base_dir = output_dir

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # If passed along, set the training seed now.
    if seed is not None:
        set_seed(seed)

    if accelerator.is_main_process:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        output_dir = os.path.join(output_dir, now)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Build the triplane VAE from its config.
    config_vae = OmegaConf.load(vae_config)
    if rollout:
        vae = AutoencoderKLRollOut(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    else:
        vae = AutoencoderKL(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    print(f"VAE total params = {sum(p.numel() for p in vae.parameters())}")

    if 'perceptual_weight' in config_vae['lossconfig']['params'].keys():
        config_vae['lossconfig']['params']['device'] = str(accelerator.device)
    loss_fn = instantiate_from_config(config_vae['lossconfig'])

    conditioning_params = torch.load(conditioning_params_dir).to(accelerator.device)
    data_std = torch.load(std_dir).to(accelerator.device).reshape(1, -1, 1, 1, 1)
    data_mean = torch.load(mean_dir).to(accelerator.device).reshape(1, -1, 1, 1, 1)

    # Load the frozen GAN generators that synthesize training triplanes on the fly.
    print("########## gan model load ##########")
    config_gan_model = OmegaConf.load(gan_model_config)
    gan_model_all = gan_model(config_gan_model['gan_models'], str(accelerator.device), gan_model_base_dir)
    print("########## gan model loaded ##########")

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs.
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. "
                "You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        vae.parameters(),
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Dataset and DataLoader creation.
    train_dataset = TriplaneDataset(**train_data)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, collate_fn=collate_fn, shuffle=True, num_workers=2
    )

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        vae, optimizer, train_dataloader, lr_scheduler
    )

    # Move the VAE to GPU and cast it to the mixed-precision weight dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs.
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("trainvae", config=dict(config))

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save.
    if resume_from_checkpoint:
        if resume_from_checkpoint != "latest":
            path = os.path.basename(resume_from_checkpoint)
        else:
            # Get the most recent checkpoint.
            dirs = os.listdir(output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        if resume_from_checkpoint != "latest":
            accelerator.load_state(resume_from_checkpoint)
        else:
            accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])

        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch
    else:
        # Otherwise, look for the latest checkpoint under any previous timestamped run in base_dir.
        all_final_training_dirs = []
        dirs = os.listdir(base_dir)
        if len(dirs) != 0:
            dirs = [d for d in dirs if d.startswith("2024")]  # run directories are timestamped (specific year)
            if len(dirs) != 0:
                base_resume_paths = [os.path.join(base_dir, d) for d in dirs]
                for base_resume_path in base_resume_paths:
                    checkpoint_file_names = os.listdir(base_resume_path)
                    checkpoint_file_names = [d for d in checkpoint_file_names if d.startswith("checkpoint")]
                    if len(checkpoint_file_names) != 0:
                        for checkpoint_file_name in checkpoint_file_names:
                            final_training_dir = os.path.join(base_resume_path, checkpoint_file_name)
                            all_final_training_dirs.append(final_training_dir)
                if len(all_final_training_dirs) != 0:
                    # Pick the checkpoint with the largest global step across all runs.
                    sorted_all_final_training_dirs = sorted(
                        all_final_training_dirs, key=lambda x: int(os.path.basename(x).split("-")[1])
                    )
                    latest_dir = sorted_all_final_training_dirs[-1]
                    path = os.path.basename(latest_dir)
                    accelerator.print(f"Resuming from checkpoint {path}")
                    accelerator.load_state(latest_dir)
                    global_step = int(path.split("-")[1])

                    first_epoch = global_step // num_update_steps_per_epoch
                    resume_step = global_step % num_update_steps_per_epoch
                else:
                    accelerator.print("Training from start")
            else:
                accelerator.print("Training from start")
        else:
            accelerator.print("Training from start")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step (disabled).
            # if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue

            with accelerator.accumulate(vae):
                # Generate the target triplanes on the fly from the frozen GAN generators.
                z_values = batch["zs"].to(weight_dtype)
                model_names = batch["model_names"]
                triplane_values = []
                with torch.no_grad():
                    for z_id in range(z_values.shape[0]):
                        z_value = z_values[z_id].unsqueeze(0)
                        model_name = model_names[z_id]
                        triplane_value = triplane_generate(gan_model_all[model_name], z_value,
                                                           conditioning_params, data_std, data_mean)
                        triplane_values.append(triplane_value)
                triplane_values = torch.cat(triplane_values, dim=0)
                vert_values = batch["verts"].to(weight_dtype)

                triplane_values = rearrange(triplane_values, "b f c h w -> b c f h w")
                if rollout:
                    triplane_values_roll = rollout_fn(triplane_values.clone())
                    reconstructions, posterior = vae(triplane_values_roll)
                    reconstructions_unroll = unrollout_fn(reconstructions)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions_unroll, posterior, vert_values,
                                                split="train")
                else:
                    reconstructions, posterior = vae(triplane_values)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions, posterior, vert_values,
                                                split="train")

                train_loss += loss.detach().item() / gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = log_dict_ae
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/triplane_vae.yaml")
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
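
# Illustrative layout of the YAML passed via --config (keys mirror main()'s signature).
# All paths below are placeholders, and train_data must hold whatever kwargs TriplaneDataset expects:
#
#   vae_config: ./configs/vae_model.yaml
#   gan_model_config: ./configs/gan_models.yaml
#   output_dir: ./outputs/triplane_vae
#   std_dir: ./stats/std.pt
#   mean_dir: ./stats/mean.pt
#   conditioning_params_dir: ./stats/conditioning_params.pt
#   gan_model_base_dir: ./pretrained/next3d
#   train_data:
#     ...
#   train_batch_size: 2
#   max_train_steps: 500
#   learning_rate: 3.0e-5
#   mixed_precision: fp16
#   rollout: false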