"""Fine-tuning script for Stable Video Diffusion for image2video with support for LoRA."""
import logging
import math
import os
import shutil
from glob import glob
from pathlib import Path
from PIL import Image

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from einops import rearrange
import transformers
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import copy

import diffusers
from diffusers import AutoencoderKLTemporalDecoder
from diffusers import  UNetSpatioTemporalConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing


from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import StableVideoDiffusionWithRefAttnMapPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore, 
                                         register_temporal_self_attention_control, 
                                         register_temporal_self_attention_flip_control,
)
from utils.parse_args import parse_args
from dataset.stable_video_dataset import StableVideoDataset

# Module-level logger created through accelerate's get_logger at INFO level.
logger = get_logger(__name__, log_level="INFO")

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draw samples from a log-normal distribution by inverse-CDF sampling.

    Uniform samples are mapped through the inverse CDF of ``Normal(loc, scale)``
    and then exponentiated.

    Args:
        shape: Shape of the tensor of samples to draw.
        loc: Mean of the underlying normal distribution.
        scale: Standard deviation of the underlying normal distribution.
        device: Device on which the uniform samples are created.
        dtype: Dtype of the uniform samples.

    Returns:
        A tensor of the requested shape holding the exponentiated normal samples.
    """
    uniform = torch.rand(shape, dtype=dtype, device=device)
    # Shrink the range away from 0 and 1 so the inverse CDF stays finite.
    uniform = uniform * (1 - 2e-7) + 1e-7
    gaussian = torch.distributions.Normal(loc, scale)
    normal_samples = gaussian.icdf(uniform)
    return normal_samples.exp()

def main():
    """Fine-tune selected temporal self-attention weights of a Stable Video Diffusion UNet.

    Configuration is read from ``parse_args()``. The function:

    1. builds the ``Accelerator`` and configures logging / seeding;
    2. loads the scheduler, CLIP feature extractor and image encoder, VAE and
       UNet, deep-copies the UNet into a frozen reference UNet and registers the
       attention-control processors on both;
    3. unfreezes only the ``temporal_transformer_blocks.0.attn1`` ``to_v`` and
       ``to_out.0`` weights of the main UNet and optimizes them with AdamW;
    4. runs the training loop with periodic checkpointing and (optionally)
       validation video generation logged to tensorboard trackers;
    5. saves the final pipeline to ``args.output_dir``.

    Side effects: on the main process, ``frozen_param.txt`` and
    ``train_param.txt`` are written to the current working directory, and
    checkpoints / the final pipeline are written under ``args.output_dir``.

    Raises:
        ImportError: if wandb reporting is requested but wandb is not installed.
        ValueError: if xformers attention is requested but xformers is missing,
            or if the additional time-id embedding size does not match the UNet.
    """
    args = parse_args()

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb  # noqa: F401  (imported to fail early if the package is broken)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Create the output directory (main process only).
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, feature extractor and models.
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="image_encoder", variant=args.variant
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", variant=args.variant
    )
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, variant=args.variant
    )
    # The reference UNet starts as an exact copy of the pretrained UNet and stays frozen.
    ref_unet = copy.deepcopy(unet)

    # Register customized attention processors: the reference UNet gets its own
    # store, and the trained UNet's control is given both stores.
    controller_ref = AttentionStore()
    register_temporal_self_attention_control(ref_unet, controller_ref)

    controller = AttentionStore()
    register_temporal_self_attention_flip_control(unet, controller, controller_ref)

    # Freeze everything first; the trainable subset is re-enabled below.
    ref_unet.requires_grad_(False)
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # For mixed precision training the inference-only models (vae, image encoder,
    # reference unet) are cast to half precision; they never receive gradients.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae, image_encoder and ref_unet to device and cast to weight_dtype.
    # The trained unet is not cast here; it is handed to accelerator.prepare() below.
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    ref_unet.to(accelerator.device, dtype=weight_dtype)

    # Select the parameters to train by substring match on the parameter name.
    # Extend this tuple to train additional weights.
    trainable_name_patterns = (
        'temporal_transformer_blocks.0.attn1.to_v.weight',
        'temporal_transformer_blocks.0.attn1.to_out.0.weight',
    )
    unet_train_params_list = []
    for name, para in unet.named_parameters():
        is_trainable = any(pattern in name for pattern in trainable_name_patterns)
        para.requires_grad = is_trainable
        if is_trainable:
            unet_train_params_list.append(para)

    if args.mixed_precision == "fp16":
        # only upcast trainable parameters into fp32
        cast_training_params(unet, dtype=torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                # `warning` is used because `warn` is deprecated in the logging module.
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            """Save each tracked model in diffusers format under ``output_dir/unet``."""
            if accelerator.is_main_process:
                for model in models:
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            """Restore each tracked model from the diffusers-format ``input_dir/unet``."""
            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Record which parameter names are frozen / trainable for later inspection.
    if accelerator.is_main_process:
        with open('frozen_param.txt', 'w') as frozen_file, open('train_param.txt', 'w') as train_file:
            for name, para in unet.named_parameters():
                if para.requires_grad is False:
                    frozen_file.write(f'{name}\n')
                else:
                    train_file.write(f'{name}\n')

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        unet_train_params_list,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    def unwrap_model(model):
        """Return the bare module behind accelerate and torch.compile wrappers."""
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    train_dataset = StableVideoDataset(video_data_dir=args.train_data_dir,
                                       max_num_videos=args.max_train_samples,
                                       num_frames=args.num_frames,
                                       is_reverse_video=True,
                                       double_sampling_rate=args.double_sampling_rate)

    def collate_fn(examples):
        """Stack per-example ``pixel_values`` / ``conditions`` into float batch tensors."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditions = torch.stack([example["conditions"] for example in examples])
        conditions = conditions.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values, "conditions": conditions}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Validation data: the first `num_validation_images` PNGs (sorted by path),
    # converted to RGB and resized to 1024x576.
    validation_images = []
    if args.validation_data_dir is not None:
        validation_image_paths = sorted(glob(os.path.join(args.validation_data_dir, '*.png')))
        num_validation_images = min(args.num_validation_images, len(validation_image_paths))
        validation_image_paths = validation_image_paths[:num_validation_images]
        validation_images = [Image.open(image_path).convert('RGB').resize((1024, 576)) for image_path in validation_image_paths]

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("image2video-reverse-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # default motion param setting
    def _get_add_time_ids(
        dtype,
        batch_size,
        fps=6,
        motion_bucket_id=127,
        noise_aug_strength=0.02,
    ):
        """Build the ``(batch_size, 3)`` tensor of [fps, motion_bucket_id, noise_aug_strength].

        Raises:
            ValueError: if the UNet's additional embedding layer does not expect
                exactly this many time ids.
        """
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        # Unwrap instead of reaching for `.module`: the prepared model only has
        # a `.module` attribute when it was wrapped (e.g. multi-process runs).
        base_unet = unwrap_model(unet)
        passed_add_embed_dim = base_unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = base_unet.add_embedding.linear_1.in_features
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"UNet expects an additional embedding of size {expected_add_embed_dim}, "
                f"but the time ids produce size {passed_add_embed_dim}."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size, 1)
        return add_time_ids

    def compute_image_embeddings(image):
        """Encode a batch of images in [-1, 1] into CLIP image embeddings of shape (B, 1, D)."""
        image = _resize_with_antialiasing(image, (224, 224))
        # Map from [-1, 1] to [0, 1] before CLIP normalization.
        image = (image + 1.0) / 2.0
        # Normalize the image for CLIP input
        image = feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(accelerator.device).to(dtype=weight_dtype)
        image_embeddings = image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)
        return image_embeddings

    noise_aug_strength = 0.02
    fps = 7

    def run_validation(pipeline, log_step):
        """Generate one video per validation image and log them to tensorboard trackers.

        Does nothing beyond generation when there are no videos or no
        tensorboard tracker is active.
        """
        generator = torch.Generator(device=accelerator.device)
        if args.seed is not None:
            generator = generator.manual_seed(args.seed)
        videos = []
        with torch.cuda.amp.autocast():
            for val_img in validation_images:
                videos.append(
                    pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
                )

        tensorboard_trackers = [tracker for tracker in accelerator.trackers if tracker.name == "tensorboard"]
        if not videos or not tensorboard_trackers:
            return
        stacked_videos = torch.stack(videos)
        for tracker in tensorboard_trackers:
            tracker.writer.add_video("validation", stacked_videos, log_step, fps=fps)

    # Bound up front so the final validation can log even if the loop below never runs.
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Get the image embedding for conditioning
                encoder_hidden_states = compute_image_embeddings(batch["conditions"])
                encoder_hidden_states_ref = compute_image_embeddings(batch["pixel_values"][:, -1])

                batch["conditions"] = batch["conditions"].to(accelerator.device).to(dtype=weight_dtype)
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device).to(dtype=weight_dtype)

                # Get the image latent for input conditioning (noise-augmented condition frame,
                # repeated along the frame axis).
                noise = torch.randn_like(batch["conditions"])
                conditions = batch["conditions"] + noise_aug_strength * noise
                conditions_latent = vae.encode(conditions).latent_dist.mode()
                conditions_latent = conditions_latent.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)

                # The reference branch is conditioned on the last frame of the clip,
                # using the same augmentation noise.
                conditions_ref = batch["pixel_values"][:, -1] + noise_aug_strength * noise
                conditions_latent_ref = vae.encode(conditions_ref).latent_dist.mode()
                conditions_latent_ref = conditions_latent_ref.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)

                # Convert frames to latent space
                pixel_values = rearrange(batch["pixel_values"], "b f c h w -> (b f) c h w")
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                latents = rearrange(latents, "(b f) c h w -> b f c h w", f=args.num_frames)
                # Reference latents are the same clip with the frame order reversed.
                latents_ref = torch.flip(latents, dims=(1,))

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], latents.shape[2], 1, 1), device=latents.device
                    )

                bsz = latents.shape[0]
                # Sample a random noise level for each clip
                # P_mean=0.7 P_std=1.6
                sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device)
                # Timestep conditioning is 0.25 * log(sigma), one value per clip.
                timesteps = (0.25 * sigmas.log()).to(accelerator.device)
                # Broadcast sigmas over (frames, channels, height, width).
                sigmas = sigmas[:, None, None, None, None]

                # Add noise to the latents according to the noise magnitude
                # (this is the forward diffusion process)
                noisy_latents = latents + noise * sigmas
                noisy_latents_inp = noisy_latents / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_inp = torch.cat([noisy_latents_inp, conditions_latent], dim=2)

                # Same noise, flipped along the frame axis, for the reference branch.
                noisy_latents_ref = latents_ref + torch.flip(noise, dims=(1,)) * sigmas
                noisy_latents_ref_inp = noisy_latents_ref / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_ref_inp = torch.cat([noisy_latents_ref_inp, conditions_latent_ref], dim=2)

                # The loss target is the clean latents.
                target = latents
                added_time_ids = _get_add_time_ids(encoder_hidden_states.dtype, bsz).to(accelerator.device)
                # The reference UNet's prediction is not used in the loss; it runs
                # before the trained UNet, presumably so that controller_ref holds
                # the attention maps the flip control reads
                # (NOTE(review): confirm against attn_ctrl.attention_control).
                ref_unet(noisy_latents_ref_inp.to(weight_dtype), timesteps.to(weight_dtype),
                         encoder_hidden_states=encoder_hidden_states_ref,
                         added_time_ids=added_time_ids,
                         return_dict=False)
                model_pred = unet(noisy_latents_inp, timesteps,
                                  encoder_hidden_states=encoder_hidden_states,
                                  added_time_ids=added_time_ids,
                                  return_dict=False)[0]  # v-prediction
                # Denoise the latents
                c_out = -sigmas / ((sigmas**2 + 1)**0.5)
                c_skip = 1 / (sigmas**2 + 1)
                denoised_latents = model_pred * c_out + c_skip * noisy_latents
                weighing = (1 + sigmas ** 2) * (sigmas**-2.0)

                # Weighted MSE loss, averaged per sample and then over the batch
                loss = torch.mean(
                    (weighing.float() * (denoised_latents.float() -
                     target.float()) ** 2).reshape(target.shape[0], -1),
                    dim=1,
                )
                loss = loss.mean()
                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = unet_train_params_list
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_data_dir is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_data_dir}."
                )
                # create pipeline around the current (unwrapped) UNet
                pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    scheduler=noise_scheduler,
                    unet=unwrap_model(unet),
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference and log the resulting videos
                run_validation(pipeline, epoch)

                del pipeline
                torch.cuda.empty_cache()

    # Save the fine-tuned UNet as part of a full pipeline
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)

        unwrapped_unet = unwrap_model(unet)
        pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            scheduler=noise_scheduler,
            unet=unwrapped_unet,
            variant=args.variant,
        )
        pipeline.save_pretrained(args.output_dir)
        # Final inference with the saved pipeline
        if args.validation_data_dir is not None:
            pipeline = pipeline.to(accelerator.device)
            pipeline.torch_dtype = weight_dtype
            run_validation(pipeline, epoch)

    accelerator.end_training()


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()