diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..debd8f805028597a2a924b185019af58dc2c90ba
--- /dev/null
+++ b/app.py
@@ -0,0 +1,170 @@
+# vae:
+#   class_path: src.models.vae.LatentVAE
+#   init_args:
+#     precompute: true
+#     weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+# denoiser:
+#   class_path: src.models.denoiser.decoupled_improved_dit.DDT
+#   init_args:
+#     in_channels: 4
+#     patch_size: 2
+#     num_groups: 16
+#     hidden_size: &hidden_dim 1152
+#     num_blocks: 28
+#     num_encoder_blocks: 22
+#     num_classes: 1000
+# conditioner:
+#   class_path: src.models.conditioner.LabelConditioner
+#   init_args:
+#     null_class: 1000
+# diffusion_sampler:
+#   class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+#   init_args:
+#     num_steps: 250
+#     guidance: 3.0
+#     state_refresh_rate: 1
+#     guidance_interval_min: 0.3
+#     guidance_interval_max: 1.0
+#     timeshift: 1.0
+#     last_step: 0.04
+#     scheduler: *scheduler
+#     w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+#     guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+#     step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+
+import torch
+import argparse
+from omegaconf import OmegaConf
+from src.models.vae import fp2uint8
+from src.diffusion.base.guidance import simple_guidance_fn
+from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
+from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
+from PIL import Image
+import gradio as gr
+from huggingface_hub import snapshot_download
+
+
+def instantiate_class(config):
+    kwargs = config.get("init_args", {})
+    class_module, class_name = config["class_path"].rsplit(".", 1)
+    module = __import__(class_module, fromlist=[class_name])
+    args_class = getattr(module, class_name)
+    return args_class(**kwargs)
+
+def load_model(weight_dict, denoiser):
+    prefix = "ema_denoiser."
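+    # The checkpoint stores EMA weights under the "ema_denoiser." prefix; copy each
+    # tensor into the freshly instantiated denoiser and report any key that fails.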
+    for k, v in denoiser.state_dict().items():
+        try:
+            v.copy_(weight_dict["state_dict"][prefix + k])
+        except Exception:
+            print(f"Failed to copy {prefix + k} to denoiser weight")
+    return denoiser
+
+
+class Pipeline:
+    def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
+        self.vae = vae
+        self.denoiser = denoiser
+        self.conditioner = conditioner
+        self.diffusion_sampler = diffusion_sampler
+        self.resolution = resolution
+
+    @torch.no_grad()
+    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+    def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
+        self.diffusion_sampler.num_steps = num_steps
+        self.diffusion_sampler.guidance = guidance
+        self.diffusion_sampler.state_refresh_rate = state_refresh_rate
+        self.diffusion_sampler.guidance_interval_min = guidance_interval_min
+        self.diffusion_sampler.guidance_interval_max = guidance_interval_max
+        self.diffusion_sampler.timeshift = timeshift
+        generator = torch.Generator(device="cuda").manual_seed(seed)
+        xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
+        with torch.no_grad():
+            condition, uncondition = self.conditioner([y,]*num_images)
+        # Sample images:
+        samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
+        samples = self.vae.decode(samples)
+        # fp32 -1,1 -> uint8 0,255
+        samples = fp2uint8(samples)
+        samples = samples.permute(0, 2, 3, 1).cpu().numpy()
+        images = []
+        for i in range(num_images):
+            image = Image.fromarray(samples[i])
+            images.append(image)
+        return images
+
+import os
+import spaces
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
+    parser.add_argument("--resolution", type=int, default=512)
+    parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R512")
+    parser.add_argument("--ckpt_path", type=str, default="models")
+    args = parser.parse_args()
+
+    if not os.path.exists(args.ckpt_path):
+        snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)
+
+    config = OmegaConf.load(args.config)
+    vae_config = config.model.vae
+    diffusion_sampler_config = config.model.diffusion_sampler
+    denoiser_config = config.model.denoiser
+    conditioner_config = config.model.conditioner
+
+    vae = instantiate_class(vae_config)
+    denoiser = instantiate_class(denoiser_config)
+    conditioner = instantiate_class(conditioner_config)
+
+
+    diffusion_sampler = EulerSampler(
+        scheduler=LinearScheduler(),
+        w_scheduler=LinearScheduler(),
+        guidance_fn=simple_guidance_fn,
+        num_steps=50,
+        guidance=3.0,
+        state_refresh_rate=1,
+        guidance_interval_min=0.3,
+        guidance_interval_max=1.0,
+        timeshift=1.0
+    )
+    ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
+    ckpt = torch.load(ckpt_path, map_location="cpu")
+    denoiser = load_model(ckpt, denoiser)
+    denoiser = denoiser.cuda()
+    vae = vae.cuda()
+    denoiser.eval()
+
+    pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
+
+    with gr.Blocks() as demo:
+        gr.Markdown("DDT")
+        with gr.Row():
+            with gr.Column(scale=1):
+                num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
+                guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
+                num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
+                label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
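+                # slider values are passed positionally to Pipeline.__call__ via btn.click below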
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) + state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) + guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0) + guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) + timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0) + with gr.Column(scale=2): + btn = gr.Button("Generate") + output = gr.Gallery(label="Images") + + btn.click(fn=pipeline, + inputs=[ + label, + num_images, + seed, + num_steps, + guidance, + state_refresh_rate, + guidance_interval_min, + guidance_interval_max, + timeshift + ], outputs=[output]) + demo.launch(server_name="0.0.0.0", server_port=7861) \ No newline at end of file diff --git a/configs/repa_improved_ddt_xlen22de6_256.yaml b/configs/repa_improved_ddt_xlen22de6_256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d35ea6f5305b655a0e518621bb4229566ec36671 --- /dev/null +++ b/configs/repa_improved_ddt_xlen22de6_256.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: stabilityai/sd-vae-ft-ema + denoiser: + class_path: src.models.denoiser.decoupled_improved_dit.DDT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_encoder_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 2.0 + timeshift: 1.0 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + last_step: 0.04 + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: 
torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/configs/repa_improved_ddt_xlen22de6_512.yaml b/configs/repa_improved_ddt_xlen22de6_512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..604c53b2c5e2c5df0b107b2937e7347a56c9fb7c --- /dev/null +++ b/configs/repa_improved_ddt_xlen22de6_512.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: stabilityai/sd-vae-ft-ema + denoiser: + class_path: src.models.denoiser.decoupled_improved_dit.DDT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_encoder_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 3.0 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + timeshift: 1.0 + last_step: 0.04 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet512 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 512 + train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 32 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 64 + - 64 \ No newline at 
end of file diff --git a/configs/repa_improved_dit_large.yaml b/configs/repa_improved_dit_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6a51b90dab7aafdef98c3817190220724025e93 --- /dev/null +++ b/configs/repa_improved_dit_large.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_improved_dit_large +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.improved_dit.DiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1024 + num_blocks: 24 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/configs/repa_improved_dit_xl.yaml b/configs/repa_improved_dit_xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..593c04853bb461593c847b591441d3f876fb960d --- /dev/null +++ b/configs/repa_improved_dit_xl.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_improved_dit_xlen22de6_512 +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: 
lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.improved_dit.DiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1f237c2e2dc64130c0597a078052c270041baca --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +lightning==2.5.0.post0 +omegaconf==2.3.0 +torch==2.3.0 +diffusers==0.30.0 +jsonargparse[signatures]>=4.27.7 +accelerate \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/callbacks/grad.py b/src/callbacks/grad.py new file mode 100644 index 0000000000000000000000000000000000000000..f9155b68152ac8b4ce2f30e31501ec892175b167 --- /dev/null +++ b/src/callbacks/grad.py @@ -0,0 +1,22 @@ +import torch +import lightning.pytorch as pl +from lightning.pytorch.utilities import grad_norm +from torch.optim import Optimizer + +class GradientMonitor(pl.Callback): + """Logs the gradient norm""" + + def __init__(self, norm_type: int = 2): + norm_type = float(norm_type) + if norm_type <= 0: + raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). 
Got {norm_type}") + self.norm_type = norm_type + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + optimizer: Optimizer + ) -> None: + norms = grad_norm(pl_module, norm_type=self.norm_type) + max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max() + pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]}) \ No newline at end of file diff --git a/src/callbacks/model_checkpoint.py b/src/callbacks/model_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9b8e6ff2e35ef37041e9329374bc64ae71fb48 --- /dev/null +++ b/src/callbacks/model_checkpoint.py @@ -0,0 +1,21 @@ +import os.path +from typing import Optional, Dict, Any + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from soupsieve.util import lower + + +class CheckpointHook(ModelCheckpoint): + """Save checkpoint with only the incremental part of the model""" + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.dirpath = trainer.default_root_dir + self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt") + pl_module.strict_loading = False + + def on_save_checkpoint( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any] + ) -> None: + del checkpoint["callbacks"] \ No newline at end of file diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py new file mode 100644 index 0000000000000000000000000000000000000000..c6cd32b5ffd7ee83310dea425abb8039aa80aa32 --- /dev/null +++ b/src/callbacks/save_images.py @@ -0,0 +1,105 @@ +import lightning.pytorch as pl +from lightning.pytorch import Callback + + +import os.path +import numpy +from PIL import Image +from typing import Sequence, Any, Dict +from concurrent.futures import ThreadPoolExecutor + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from lightning_utilities.core.rank_zero import rank_zero_info + +def process_fn(image, path): + Image.fromarray(image).save(path) + +class SaveImagesHook(Callback): + def __init__(self, save_dir="val", max_save_num=0, compressed=True): + self.save_dir = save_dir + self.max_save_num = max_save_num + self.compressed = compressed + + def save_start(self, target_dir): + self.target_dir = target_dir + self.executor_pool = ThreadPoolExecutor(max_workers=8) + if not os.path.exists(self.target_dir): + os.makedirs(self.target_dir, exist_ok=True) + else: + if os.listdir(target_dir) and "debug" not in str(target_dir): + raise FileExistsError(f'{self.target_dir} already exists and not empty!') + self.samples = [] + self._have_saved_num = 0 + rank_zero_info(f"Save images to {self.target_dir}") + + def save_image(self, images, filenames): + images = images.permute(0, 2, 3, 1).cpu().numpy() + for sample, filename in zip(images, filenames): + if isinstance(filename, Sequence): + filename = filename[0] + path = f'{self.target_dir}/{filename}' + if self._have_saved_num >= self.max_save_num: + break + self.executor_pool.submit(process_fn, sample, path) + self._have_saved_num += 1 + + def process_batch( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: STEP_OUTPUT, + batch: Any, + ) -> None: + b, c, h, w = samples.shape + xT, y, metadata = batch + all_samples = pl_module.all_gather(samples).view(-1, c, h, w) + self.save_image(samples, metadata) + if trainer.is_global_zero: + 
all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy() + self.samples.append(all_samples) + + def save_end(self): + if self.compressed and len(self.samples) > 0: + samples = numpy.concatenate(self.samples) + numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples) + self.executor_pool.shutdown(wait=True) + self.samples = [] + self.target_dir = None + self._have_saved_num = 0 + self.executor_pool = None + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}") + self.save_start(target_dir) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, outputs, batch) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict") + self.save_start(target_dir) + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, samples, batch) + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() \ No newline at end of file diff --git a/src/callbacks/simple_ema.py b/src/callbacks/simple_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..28bf476b7e87969640490103f039ef15a9d01279 --- /dev/null +++ b/src/callbacks/simple_ema.py @@ -0,0 +1,79 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +import threading +import lightning.pytorch as pl +from lightning.pytorch import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from src.utils.copy import swap_tensors + +class SimpleEMA(Callback): + def __init__(self, net:nn.Module, ema_net:nn.Module, + decay: float = 0.9999, + every_n_steps: int = 1, + eval_original_model:bool = False + ): + super().__init__() + self.decay = decay + self.every_n_steps = every_n_steps + self.eval_original_model = eval_original_model + self._stream = torch.cuda.Stream() + + self.net_params = list(net.parameters()) + self.ema_params = list(ema_net.parameters()) + + def swap_model(self): + for ema_p, p, in zip(self.ema_params, self.net_params): + swap_tensors(ema_p, p) + + def ema_step(self): + @torch.no_grad() + def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + if self._stream is not None: + self._stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._stream): + ema_update(self.ema_params, self.net_params, self.decay) + + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if trainer.global_step % self.every_n_steps == 0: + self.ema_step() + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_validation_epoch_end(self, trainer: 
"pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + + def state_dict(self) -> Dict[str, Any]: + return { + "decay": self.decay, + "every_n_steps": self.every_n_steps, + "eval_original_model": self.eval_original_model, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.eval_original_model = state_dict["eval_original_model"] + diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1 @@ + diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/dataset/celeba.py b/src/data/dataset/celeba.py new file mode 100644 index 0000000000000000000000000000000000000000..30f5d3f041a6014759a58b9ba4aec5778adf5d10 --- /dev/null +++ b/src/data/dataset/celeba.py @@ -0,0 +1,11 @@ +from typing import Callable +from torchvision.datasets import CelebA + + +class LocalDataset(CelebA): + def __init__(self, root:str, ): + super(LocalDataset, self).__init__(root, "train") + + def __getitem__(self, idx): + data = super().__getitem__(idx) + return data \ No newline at end of file diff --git a/src/data/dataset/imagenet.py b/src/data/dataset/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..59d0547ce65f7b61eca5187fa4740de997989c4c --- /dev/null +++ b/src/data/dataset/imagenet.py @@ -0,0 +1,82 @@ +import torch +from PIL import Image +from torchvision.datasets import ImageFolder +from torchvision.transforms.functional import to_tensor +from torchvision.transforms import Normalize + +from src.data.dataset.metric_dataset import CenterCrop + +class LocalCachedDataset(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + self.transform = CenterCrop(resolution) + self.cache_root = None + + def load_latent(self, latent_path): + pk_data = torch.load(latent_path) + mean = pk_data['mean'].to(torch.float32) + logvar = pk_data['logvar'].to(torch.float32) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + latent = mean + torch.randn_like(mean) * std + return latent + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + latent_path = image_path.replace(self.root, self.cache_root) + ".pt" + + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + if self.cache_root is not None: + latent = self.load_latent(latent_path) + else: + latent = raw_image + return raw_image, latent, target + +class ImageNet256(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 256) + self.cache_root = root + "_256_latent" + +class ImageNet512(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 512) + self.cache_root = root + "_512_latent" + +class PixImageNet(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + 
self.transform = CenterCrop(resolution) + self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + + normalized_image = self.normalize(raw_image) + return raw_image, normalized_image, target + +class PixImageNet64(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 64) + +class PixImageNet128(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 128) + + +class PixImageNet256(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 256) + +class PixImageNet512(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 512) + + + + + diff --git a/src/data/dataset/metric_dataset.py b/src/data/dataset/metric_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe7d66281015415012cc0af3f429ad84154c866 --- /dev/null +++ b/src/data/dataset/metric_dataset.py @@ -0,0 +1,82 @@ +import pathlib + +import torch +import random +import numpy as np +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset + +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + + +from PIL import Image +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +def test_collate(batch): + return torch.stack(batch) + +class ImageDataset(Dataset): + def __init__(self, root, image_size=(224, 224)): + self.root = pathlib.Path(root) + images = [] + for ext in IMG_EXTENSIONS: + images.extend(self.root.rglob(ext)) + random.shuffle(images) + self.images = list(map(lambda x: str(x), images)) + self.transform = tvtf.Compose( + [ + CenterCrop(image_size[0]), + tvtf.ToTensor(), + tvtf.Lambda(lambda x: (x*255).to(torch.uint8)), + tvtf.Lambda(lambda x: x.expand(3, -1, -1)) + ] + ) + self.size = image_size + + def __getitem__(self, idx): + try: + image = Image.open(self.images[idx]) + image = self.transform(image) + except Exception as e: + print(self.images[idx]) + image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8) + + # print(image) + metadata = dict( + path = self.images[idx], + root = self.root, + ) + return image #, metadata + + def __len__(self): + return len(self.images) \ No newline at end of file diff --git a/src/data/dataset/randn.py b/src/data/dataset/randn.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ec7727261dce9db0f4abcedefa9a03cbe8ab60 --- /dev/null +++ b/src/data/dataset/randn.py @@ -0,0 +1,41 @@ +import os.path +import random + +import torch +from torch.utils.data import Dataset + + + +class 
RandomNDataset(Dataset): + def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, ): + self.selected_classes = selected_classes + if selected_classes is not None: + num_classes = len(selected_classes) + max_num_instances = 10*num_classes + self.num_classes = num_classes + self.seeds = seeds + if seeds is not None: + self.max_num_instances = len(seeds)*num_classes + self.num_seeds = len(seeds) + else: + self.num_seeds = (max_num_instances + num_classes - 1) // num_classes + self.max_num_instances = self.num_seeds*num_classes + + self.latent_shape = latent_shape + + + def __getitem__(self, idx): + label = idx // self.num_seeds + if self.selected_classes: + label = self.selected_classes[label] + seed = random.randint(0, 1<<31) #idx % self.num_seeds + if self.seeds is not None: + seed = self.seeds[idx % self.num_seeds] + + # cls_dir = os.path.join(self.root, f"{label}") + filename = f"{label}_{seed}.png", + generator = torch.Generator().manual_seed(seed) + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + return latent, label, filename + def __len__(self): + return self.max_num_instances \ No newline at end of file diff --git a/src/data/var_training.py b/src/data/var_training.py new file mode 100644 index 0000000000000000000000000000000000000000..de7fb740e1a9639b20a804507c2cfea65bfaa598 --- /dev/null +++ b/src/data/var_training.py @@ -0,0 +1,145 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +import concurrent.futures +from concurrent.futures import ProcessPoolExecutor +from typing import List +from PIL import Image +import torch +import random +import numpy as np +import copy +import torchvision.transforms.functional as tvtf +from src.models.vae import uint82fp + + +def center_crop_arr(pil_image, width, height): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = max(width / pil_image.size[0], height / pil_image.size[1]) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + arr = np.array(pil_image) + crop_y = random.randint(0, (arr.shape[0] - height)) + crop_x = random.randint(0, (arr.shape[1] - width)) + return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width]) + +def process_fn(width, height, data, hflip=0.5): + image, label = data + if random.uniform(0, 1) > hflip: # hflip + image = tvtf.hflip(image) + image = center_crop_arr(image, width, height) # crop + image = np.array(image).transpose(2, 0, 1) + return image, label + +class VARCandidate: + def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024): + self.aspect_ratio = aspect_ratio + self.width = int(width) + self.height = int(height) + self.buffer = buffer + self.max_buffer_size = max_buffer_size + + def add_sample(self, data): + self.buffer.append(data) + self.buffer = self.buffer[-self.max_buffer_size:] + + def ready(self, batch_size): + return len(self.buffer) >= batch_size + + def get_batch(self, batch_size): + batch = self.buffer[:batch_size] + self.buffer = self.buffer[batch_size:] + batch = [copy.deepcopy(b.result()) for b in batch] + x, y = zip(*batch) + x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0) + x = list(map(uint82fp, x)) + return x, y + +class VARTransformEngine: + def __init__(self, + base_image_size, + num_aspect_ratios, + min_aspect_ratio, + max_aspect_ratio, + num_workers = 8, + ): + self.base_image_size = base_image_size + self.num_aspect_ratios = num_aspect_ratios + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios) + self.aspect_ratios = self.aspect_ratios.tolist() + self.candidates_pool = [] + for i in range(self.num_aspect_ratios): + candidate = VARCandidate( + aspect_ratio=self.aspect_ratios[i], + width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16), + height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16), + buffer=[], + max_buffer_size=1024 + ) + self.candidates_pool.append(candidate) + self.default_candidate = VARCandidate( + aspect_ratio=1.0, + width=self.base_image_size, + height=self.base_image_size, + buffer=[], + max_buffer_size=1024, + ) + self.executor_pool = ProcessPoolExecutor(max_workers=num_workers) + self._prefill_count = 100 + + def find_candidate(self, data): + image = data[0] + aspect_ratio = image.size[0] / image.size[1] + min_distance = 1000000 + min_candidate = None + for candidate in self.candidates_pool: + dis = abs(aspect_ratio - candidate.aspect_ratio) + if dis < min_distance: + min_distance = dis + min_candidate = candidate + return min_candidate + + + def __call__(self, batch_data): + self._prefill_count -= 1 + if isinstance(batch_data[0], torch.Tensor): + batch_data[0] = batch_data[0].unbind(0) + + batch_data = list(zip(*batch_data)) + for data in batch_data: + candidate = self.find_candidate(data) + future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data) + candidate.add_sample(future) + if 
self._prefill_count >= 0: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + + batch_size = len(batch_data) + random.shuffle(self.candidates_pool) + for candidate in self.candidates_pool: + if candidate.ready(batch_size=batch_size): + return candidate.get_batch(batch_size=batch_size) + + # fallback to default 256 + for data in batch_data: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + return self.default_candidate.get_batch(batch_size=batch_size) \ No newline at end of file diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/diffusion/base/guidance.py b/src/diffusion/base/guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..07b4754a30c2dc3ac65ef2237d0175d6915dc0ec --- /dev/null +++ b/src/diffusion/base/guidance.py @@ -0,0 +1,60 @@ +import torch + +def simple_guidance_fn(out, cfg): + uncondition, condtion = out.chunk(2, dim=0) + out = uncondition + cfg * (condtion - uncondition) + return out + +def c3_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3]) + return out + +def c4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p05_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p10_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p15_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p20_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def p4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? 
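+    # p4 variant: channels 0-3 stay fully conditional, CFG is applied only to channels 4+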
+ uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, 4:] = uncondition[:, 4:] + cfg * (condtion[:, 4:] - uncondition[:, 4:]) + return out diff --git a/src/diffusion/base/sampling.py b/src/diffusion/base/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f9776f0b10d3ac7c6fc1d5fe30770d4c7bf13e --- /dev/null +++ b/src/diffusion/base/sampling.py @@ -0,0 +1,31 @@ +from typing import Union, List + +import torch +import torch.nn as nn +from typing import Callable +from src.diffusion.base.scheduling import BaseScheduler + +class BaseSampler(nn.Module): + def __init__(self, + scheduler: BaseScheduler = None, + guidance_fn: Callable = None, + num_steps: int = 250, + guidance: Union[float, List[float]] = 1.0, + *args, + **kwargs + ): + super(BaseSampler, self).__init__() + self.num_steps = num_steps + self.guidance = guidance + self.guidance_fn = guidance_fn + self.scheduler = scheduler + + + def _impl_sampling(self, net, noise, condition, uncondition): + raise NotImplementedError + + def __call__(self, net, noise, condition, uncondition): + denoised = self._impl_sampling(net, noise, condition, uncondition) + return denoised + + diff --git a/src/diffusion/base/scheduling.py b/src/diffusion/base/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..05c7fb18156e2e8aa28121e9ac855ba6ccf698f6 --- /dev/null +++ b/src/diffusion/base/scheduling.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor + +class BaseScheduler: + def alpha(self, t) -> Tensor: + ... + def sigma(self, t) -> Tensor: + ... + + def dalpha(self, t) -> Tensor: + ... + def dsigma(self, t) -> Tensor: + ... + + def dalpha_over_alpha(self, t) -> Tensor: + return self.dalpha(t) / self.alpha(t) + + def dsigma_mul_sigma(self, t) -> Tensor: + return self.dsigma(t)*self.sigma(t) + + def drift_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dalpha/(alpha + 1e-6) + + def diffuse_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2 + + def w(self, t): + return self.sigma(t) diff --git a/src/diffusion/base/training.py b/src/diffusion/base/training.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6d0e0f5baaacf3e9474dc3ad8df28745736728 --- /dev/null +++ b/src/diffusion/base/training.py @@ -0,0 +1,29 @@ +import time + +import torch +import torch.nn as nn + +class BaseTrainer(nn.Module): + def __init__(self, + null_condition_p=0.1, + log_var=False, + ): + super(BaseTrainer, self).__init__() + self.null_condition_p = null_condition_p + self.log_var = log_var + + def preproprocess(self, raw_iamges, x, condition, uncondition): + bsz = x.shape[0] + if self.null_condition_p > 0: + mask = torch.rand((bsz), device=condition.device) < self.null_condition_p + mask = mask.expand_as(condition) + condition[mask] = uncondition[mask] + return raw_iamges, x, condition + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + raise NotImplementedError + + def __call__(self, net, ema_net, raw_images, x, condition, uncondition): + raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition) + return self._impl_trainstep(net, ema_net, raw_images, x, condition) + diff --git a/src/diffusion/ddpm/ddim_sampling.py b/src/diffusion/ddpm/ddim_sampling.py new file mode 100644 index 
0000000000000000000000000000000000000000..0db2a1d955b5f63ef18e09624e168c691a66efd0 --- /dev/null +++ b/src/diffusion/ddpm/ddim_sampling.py @@ -0,0 +1,40 @@ +import torch +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + +import logging +logger = logging.getLogger(__name__) + +class DDIMSampler(BaseSampler): + def __init__( + self, + train_num_steps=1000, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.train_num_steps = train_num_steps + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device) + steps = torch.flip(steps, dims=[0]) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + t_cur = t_cur.repeat(batch_size) + t_next = t_next.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha = self.scheduler.alpha(t_cur) + sigma_next = self.scheduler.sigma(t_next) + alpha_next = self.scheduler.alpha(t_next) + cfg_x = torch.cat([x, x], dim=0) + t = t_cur.repeat(2) + out = net(cfg_x, t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + x0 = (x - sigma * out) / alpha + x = alpha_next * x0 + sigma_next * out + return x0 \ No newline at end of file diff --git a/src/diffusion/ddpm/scheduling.py b/src/diffusion/ddpm/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..aff1523b768b9ea83fcb5984e2190a100d5d0922 --- /dev/null +++ b/src/diffusion/ddpm/scheduling.py @@ -0,0 +1,102 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class DDPMScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.0001, + beta_max=0.02, + num_steps=1000, + ): + super().__init__() + self.beta_min = beta_min + self.beta_max = beta_max + self.num_steps = num_steps + + self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda") + self.alphas_table = torch.cumprod(1-self.betas_table, dim=0) + self.sigmas_table = 1-self.alphas_table + + + def beta(self, t) -> Tensor: + t = t.to(torch.long) + return self.betas_table[t].view(-1, 1, 1, 1) + + def alpha(self, t) -> Tensor: + t = t.to(torch.long) + return self.alphas_table[t].view(-1, 1, 1, 1)**0.5 + + def sigma(self, t) -> Tensor: + t = t.to(torch.long) + return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5 + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + raise NotImplementedError("wrong usage") + + +class VPScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.1, + beta_max=20, + ): + super().__init__() + self.beta_min = beta_min + self.beta_d = beta_max - beta_min + def beta(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1) + + def sigma(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t + return 
(1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1) + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def alpha(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t + return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1) + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + return self.diffuse_coefficient(t) + + + diff --git a/src/diffusion/ddpm/training.py b/src/diffusion/ddpm/training.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0d0ecef6d3bc2e845a3a2a8277e373a5ae2230 --- /dev/null +++ b/src/diffusion/ddpm/training.py @@ -0,0 +1,83 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class VPTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t*self.train_max_t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - noise)**2 + + out = dict( + loss=loss.mean(), + ) + return out + + +class DDPMTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn: Callable = constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + t = torch.randint(0, self.train_max_t, (batch_size,)) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight * (out - noise) ** 2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/ddpm/vp_sampling.py b/src/diffusion/ddpm/vp_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..250b32d6194f77bc7daede59a9c91fb53d53351a --- /dev/null +++ b/src/diffusion/ddpm/vp_sampling.py @@ -0,0 +1,59 @@ +import torch + +from src.diffusion.base.scheduling import * +from 
src.diffusion.base.sampling import * +from typing import Callable + +def ode_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt + +def sde_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x) + +import logging +logger = logging.getLogger(__name__) + +class VPEulerSampler(BaseSampler): + def __init__( + self, + train_max_t=1000, + guidance_fn: Callable = None, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.guidance_fn = guidance_fn + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.train_max_t = train_max_t + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device) + steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + beta = self.scheduler.beta(t_cur) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition) + eps = self.guidance_fn(out, self.guidance) + if i < self.num_steps -1 : + x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0]) + x = self.step_fn(x, eps, beta, sigma, dt) + else: + x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..15d0c78c945dd0b82005f34e8ff30b5f301da454 --- /dev/null +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -0,0 +1,107 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + 
self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/sampling.py b/src/diffusion/flow_matching/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..62bdd8bdffbda884be3f0dc35fc0bbcd2c1df18b --- /dev/null +++ b/src/diffusion/flow_matching/sampling.py @@ -0,0 +1,179 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert 
self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x + + +class HeunSampler(BaseSampler): + def __init__( + self, + scheduler: BaseScheduler = None, + w_scheduler: BaseScheduler = None, + exact_henu=False, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.exact_henu = exact_henu + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Henu sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + v_hat, s_hat = 0.0, 0.0 + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + t_hat = t_next + t_hat = t_hat.repeat(batch_size) + sigma_hat = self.scheduler.sigma(t_hat) + alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat) + dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat) + + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + if i == 0 or self.exact_henu: + cfg_x = torch.cat([x, x], dim=0) + cfg_t_cur = t_cur.repeat(2) + out = net(cfg_x, cfg_t_cur, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma) + else: + v = v_hat + s = s_hat + x_hat = self.step_fn(x, v, dt, s=s, 
w=w) + # henu correct + if i < self.num_steps -1: + cfg_x_hat = torch.cat([x_hat, x_hat], dim=0) + cfg_t_hat = t_hat.repeat(2) + out = net(cfg_x_hat, cfg_t_hat, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v_hat = out + s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat) + v = (v + v_hat) / 2 + s = (s + s_hat) / 2 + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/scheduling.py b/src/diffusion/flow_matching/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..a82cd3a2fcb5e3080710fa0208c5aafff54cd068 --- /dev/null +++ b/src/diffusion/flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! +class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/flow_matching/training.py b/src/diffusion/flow_matching/training.py new file mode 100644 index 0000000000000000000000000000000000000000..55c964d97776611769c8d15ef39180e271671afe --- /dev/null +++ b/src/diffusion/flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = 
self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_cos.py b/src/diffusion/flow_matching/training_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..aff30a720207d661e8aea4ae84e900ec31470cb6 --- /dev/null +++ b/src/diffusion/flow_matching/training_cos.py @@ -0,0 +1,59 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class COSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + fm_loss = weight*(out - v_t)**2 + cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1) + cos_loss = 1 - cos_sim + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + cos_loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py new file mode 100644 index 0000000000000000000000000000000000000000..40d80a60c4d2c47d0b347100f9e41a7589177adc --- /dev/null +++ b/src/diffusion/flow_matching/training_repa.py @@ -0,0 +1,137 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = 
timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/pre_integral.py b/src/diffusion/pre_integral.py new file mode 100644 index 0000000000000000000000000000000000000000..848533a8e1aa99b4f2249560d4e2cec550f7852c --- /dev/null +++ b/src/diffusion/pre_integral.py @@ -0,0 +1,143 @@ +import torch + +# lagrange interpolation +def lagrange_preint_o1(t1, v1, int_t_start, int_t_end): + ''' + lagrange interpolation of order 1 + Args: + t1: timestepx + v1: value field at t1 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = (int_t_end-int_t_start) + return int1*v1, (int1/int1, ) + +def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end): + ''' + lagrange interpolation of 
order 2 + Args: + t1: timestepx + t2: timestepy + v1: value field at t1 + v2: value field at t2 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2) + int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2) + int_sum = int1+int2 + return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum) + +def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end): + ''' + lagrange interpolation of order 3 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3) + int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end + int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3) + int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end + int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2) + int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end + int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int_sum = int1+int2+int3 + return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum) + +def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end): + ''' + lagrange interpolation of order 4 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + t4: timestepw + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + v4: value field at t4 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3)*(t1-t4) + int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end + int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3)*(t2-t4) + int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end + int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2)*(t3-t4) + int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end + int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int4_denom = (t4-t1)*(t4-t2)*(t4-t3) + int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end + int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start + int4 = (int4_end - int4_start)/int4_denom + int_sum = int1+int2+int3+int4 + return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum) + + +def 
lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end): + ''' + lagrange interpolation + Args: + order: order of interpolation + pre_vs: value field at pre_ts + pre_ts: timesteps + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + order = min(order, len(pre_vs), len(pre_ts)) + if order == 1: + return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end) + elif order == 2: + return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 3: + return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 4: + return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + else: + raise ValueError('Invalid order') + + +def polynomial_integral(coeffs, int_t_start, int_t_end): + ''' + polynomial integral + Args: + coeffs: coefficients of the polynomial + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + orders = len(coeffs) + int_val = 0 + for o in range(orders): + int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1)) + return int_val + diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2e95b7bbf40ba7efe2bfb5f4a4ba6f5a3881a9 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -0,0 +1,112 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + state_refresh_rate: int = 1, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + self.state_refresh_rate = state_refresh_rate + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts 
= self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + state = None + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + out = self.guidance_fn(out, self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/sampling.py b/src/diffusion/stateful_flow_matching/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..5fdfdb222c0c65f7a5e4c64121d809a1449fe6ae --- /dev/null +++ b/src/diffusion/stateful_flow_matching/sampling.py @@ -0,0 +1,103 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == 
ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/scheduling.py b/src/diffusion/stateful_flow_matching/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..a82cd3a2fcb5e3080710fa0208c5aafff54cd068 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! 
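+# GVP: cosine path with alpha(t)=cos(pi*t/2), sigma(t)=sin(pi*t/2); alpha^2 + sigma^2 = 1 at every t, so the noising path is variance-preserving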
+class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..f372028da2f921ea6c943221691e2f3fb17cfa26 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/sharing_sampling.py @@ -0,0 +1,149 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + # init recompute + self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate) + self.recompute_timesteps = list(range(self.num_steps)) + + def sharing_dp(self, net, noise, condition, uncondition): + _, C, H, W = noise.shape + B = 8 + template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device) + template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device) + template_uncondition = torch.full((B, ), 1000, device=condition.device) + _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition) + states = torch.stack(state_list) + N, B, L, C = states.shape + states = states.view(N, B*L, C ) + states = states.permute(1, 0, 2) + states = torch.nn.functional.normalize(states, dim=-1) + with 
torch.autocast(device_type="cuda", dtype=torch.float64): + sim = torch.bmm(states, states.transpose(1, 2)) + sim = torch.mean(sim, dim=0).cpu() + error_map = (1-sim).tolist() + + # init cum-error + for i in range(1, self.num_steps): + for j in range(0, i): + error_map[i][j] = error_map[i-1][j] + error_map[i][j] + + # init dp and force 0 start + C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + for i in range(1, self.num_steps+1): + C[1][i] = error_map[i - 1][0] + P[1][i] = 0 + + # dp state + for step in range(2, self.num_recompute_timesteps+1): + for i in range(step, self.num_steps+1): + min_value = 99999 + min_index = -1 + for j in range(step-1, i): + value = C[step-1][j] + error_map[i-1][j] + if value < min_value: + min_value = value + min_index = j + C[step][i] = min_value + P[step][i] = min_index + + # trace back + timesteps = [self.num_steps,] + for i in range(self.num_recompute_timesteps, 0, -1): + idx = timesteps[-1] + timesteps.append(P[i][idx]) + timesteps.reverse() + + print("recompute timesteps solved by DP: ", timesteps) + return timesteps[:-1] + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + pooled_state_list = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i in self.recompute_timesteps: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=0.0, w=0.0) + else: + x = self.last_step_fn(x, v, dt, s=0.0, w=0.0) + pooled_state_list.append(state) + return x, pooled_state_list + + def __call__(self, net, noise, condition, uncondition): + if len(self.recompute_timesteps) != self.num_recompute_timesteps: + self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition) + denoised, _ = self._impl_sampling(net, noise, condition, uncondition) + return denoised \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training.py b/src/diffusion/stateful_flow_matching/training.py new file mode 100644 index 0000000000000000000000000000000000000000..4c49e1e73ac5dff25b0208d93040fa8b633460b2 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = 
loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_repa.py b/src/diffusion/stateful_flow_matching/training_repa.py new file mode 100644 index 0000000000000000000000000000000000000000..4846d5d4fef0aacf5823462704fffc0665b9915e --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_repa.py @@ -0,0 +1,152 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + 
nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + + if getattr(net, "blocks", None) is not None: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/lightning_data.py b/src/lightning_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9f75a420a336d909203bbe802eea5f6894ccea34 --- /dev/null +++ b/src/lightning_data.py @@ -0,0 +1,162 @@ +from typing import Any +import torch +import copy +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS +from torch.utils.data import DataLoader +from src.data.dataset.randn import RandomNDataset +from src.data.var_training import VARTransformEngine + +def collate_fn(batch): + new_batch = copy.deepcopy(batch) + new_batch = list(zip(*new_batch)) + for i in range(len(new_batch)): + if isinstance(new_batch[i][0], torch.Tensor): + try: + new_batch[i] = torch.stack(new_batch[i], dim=0) + except: + print("Warning: could not stack tensors") + return new_batch + +class DataModule(pl.LightningDataModule): + def __init__(self, + train_root, + test_nature_root, + test_gen_root, + train_image_size=64, + train_batch_size=64, + train_num_workers=8, + var_transform_engine: VARTransformEngine = None, + train_prefetch_factor=2, + train_dataset: str = None, + eval_batch_size=32, + eval_num_workers=4, 
+ eval_max_num_instances=50000, + pred_batch_size=32, + pred_num_workers=4, + pred_seeds:str=None, + pred_selected_classes=None, + num_classes=1000, + latent_shape=(4,64,64), + ): + super().__init__() + pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None + + self.train_root = train_root + self.train_image_size = train_image_size + self.train_dataset = train_dataset + # stupid data_convert override, just to make nebular happy + self.train_batch_size = train_batch_size + self.train_num_workers = train_num_workers + self.train_prefetch_factor = train_prefetch_factor + + self.test_nature_root = test_nature_root + self.test_gen_root = test_gen_root + self.eval_max_num_instances = eval_max_num_instances + self.pred_seeds = pred_seeds + self.num_classes = num_classes + self.latent_shape = latent_shape + + self.eval_batch_size = eval_batch_size + self.pred_batch_size = pred_batch_size + + self.pred_num_workers = pred_num_workers + self.eval_num_workers = eval_num_workers + + self.pred_selected_classes = pred_selected_classes + + self._train_dataloader = None + self.var_transform_engine = var_transform_engine + + def setup(self, stage: str) -> None: + if stage == "fit": + assert self.train_dataset is not None + if self.train_dataset == "pix_imagenet64": + from src.data.dataset.imagenet import PixImageNet64 + self.train_dataset = PixImageNet64( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet128": + from src.data.dataset.imagenet import PixImageNet128 + self.train_dataset = PixImageNet128( + root=self.train_root, + ) + elif self.train_dataset == "imagenet256": + from src.data.dataset.imagenet import ImageNet256 + self.train_dataset = ImageNet256( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet256": + from src.data.dataset.imagenet import PixImageNet256 + self.train_dataset = PixImageNet256( + root=self.train_root, + ) + elif self.train_dataset == "imagenet512": + from src.data.dataset.imagenet import ImageNet512 + self.train_dataset = ImageNet512( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet512": + from src.data.dataset.imagenet import PixImageNet512 + self.train_dataset = PixImageNet512( + root=self.train_root, + ) + else: + raise NotImplementedError("no such dataset") + + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + if self.var_transform_engine and self.trainer.training: + batch = self.var_transform_engine(batch) + return batch + + def train_dataloader(self) -> TRAIN_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True) + self._train_dataloader = DataLoader( + self.train_dataset, + self.train_batch_size, + timeout=6000, + num_workers=self.train_num_workers, + prefetch_factor=self.train_prefetch_factor, + sampler=sampler, + collate_fn=collate_fn, + ) + return self._train_dataloader + + def val_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.eval_dataset = RandomNDataset( + latent_shape=self.latent_shape, + num_classes=self.num_classes, + max_num_instances=self.eval_max_num_instances, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return 
DataLoader(self.eval_dataset, self.eval_batch_size, + num_workers=self.eval_num_workers, + prefetch_factor=2, + collate_fn=collate_fn, + sampler=sampler + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.pred_dataset = RandomNDataset( + seeds= self.pred_seeds, + max_num_instances=50000, + num_classes=self.num_classes, + selected_classes=self.pred_selected_classes, + latent_shape=self.latent_shape, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size, + num_workers=self.pred_num_workers, + prefetch_factor=4, + collate_fn=collate_fn, + sampler=sampler + ) diff --git a/src/lightning_model.py b/src/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4602e8217f593da55edcfff889b05e0d5992fd2d --- /dev/null +++ b/src/lightning_model.py @@ -0,0 +1,123 @@ +from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict +import os.path +import copy +import torch +import torch.nn as nn +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from lightning.pytorch.callbacks import Callback + + +from src.models.vae import BaseVAE, fp2uint8 +from src.models.conditioner import BaseConditioner +from src.utils.model_loader import ModelLoader +from src.callbacks.simple_ema import SimpleEMA +from src.diffusion.base.sampling import BaseSampler +from src.diffusion.base.training import BaseTrainer +from src.utils.no_grad import no_grad, filter_nograd_tensors +from src.utils.copy import copy_params + +EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA] +OptimizerCallable = Callable[[Iterable], Optimizer] +LRSchedulerCallable = Callable[[Optimizer], LRScheduler] + + +class LightningModel(pl.LightningModule): + def __init__(self, + vae: BaseVAE, + conditioner: BaseConditioner, + denoiser: nn.Module, + diffusion_trainer: BaseTrainer, + diffusion_sampler: BaseSampler, + ema_tracker: Optional[EMACallable] = None, + optimizer: OptimizerCallable = None, + lr_scheduler: LRSchedulerCallable = None, + ): + super().__init__() + self.vae = vae + self.conditioner = conditioner + self.denoiser = denoiser + self.ema_denoiser = copy.deepcopy(self.denoiser) + self.diffusion_sampler = diffusion_sampler + self.diffusion_trainer = diffusion_trainer + self.ema_tracker = ema_tracker + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # self.model_loader = ModelLoader() + + self._strict_loading = False + + def configure_model(self) -> None: + self.trainer.strategy.barrier() + # self.denoiser = self.model_loader.load(self.denoiser) + copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser) + + # self.denoiser = torch.compile(self.denoiser) + # disable grad for conditioner and vae + no_grad(self.conditioner) + no_grad(self.vae) + no_grad(self.diffusion_sampler) + no_grad(self.ema_denoiser) + + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser) + return [ema_tracker] + + def configure_optimizers(self) -> OptimizerLRScheduler: + params_denoiser = filter_nograd_tensors(self.denoiser.parameters()) + params_trainer = 
filter_nograd_tensors(self.diffusion_trainer.parameters()) + optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser]) + if self.lr_scheduler is None: + return dict( + optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler(optimizer) + return dict( + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + def training_step(self, batch, batch_idx): + raw_images, x, y = batch + with torch.no_grad(): + x = self.vae.encode(x) + condition, uncondition = self.conditioner(y) + loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition) + self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False) + return loss["loss"] + + def predict_step(self, batch, batch_idx): + xT, y, metadata = batch + with torch.no_grad(): + condition, uncondition = self.conditioner(y) + # Sample images: + samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition) + samples = self.vae.decode(samples) + # fp32 -1,1 -> uint8 0,255 + samples = fp2uint8(samples) + return samples + + def validation_step(self, batch, batch_idx): + samples = self.predict_step(batch, batch_idx) + return samples + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = {} + self._save_to_state_dict(destination, prefix, keep_vars) + self.denoiser.state_dict( + destination=destination, + prefix=prefix+"denoiser.", + keep_vars=keep_vars) + self.ema_denoiser.state_dict( + destination=destination, + prefix=prefix+"ema_denoiser.", + keep_vars=keep_vars) + self.diffusion_trainer.state_dict( + destination=destination, + prefix=prefix+"diffusion_trainer.", + keep_vars=keep_vars) + return destination \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/conditioner.py b/src/models/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..a68fad39109c9cf87460f606fb3c2abd6ffdb586 --- /dev/null +++ b/src/models/conditioner.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class BaseConditioner(nn.Module): + def __init__(self): + super(BaseConditioner, self).__init__() + + def _impl_condition(self, y): + ... + def _impl_uncondition(self, y): + ... 
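+ # returns the paired (condition, uncondition) tensors that the samplers concatenate along the batch dim for classifier-free guidance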
+ def __call__(self, y): + condition = self._impl_condition(y) + uncondition = self._impl_uncondition(y) + return condition, uncondition + +class LabelConditioner(BaseConditioner): + def __init__(self, null_class): + super().__init__() + self.null_condition = null_class + + def _impl_condition(self, y): + return torch.tensor(y).long().cuda() + + def _impl_uncondition(self, y): + return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda() \ No newline at end of file diff --git a/src/models/denoiser/__init__.py b/src/models/denoiser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/denoiser/decoupled_improved_dit.py b/src/models/denoiser/decoupled_improved_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..f20115cb45cb5a2d118b509785f798a85cfe6626 --- /dev/null +++ b/src/models/denoiser/decoupled_improved_dit.py @@ -0,0 +1,308 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
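+ # sinusoidal timestep embedding: concatenate cos and sin of the frequency-scaled timesteps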
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class DDTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DDT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_encoder_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_encoder_blocks = num_encoder_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + 
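+ # cache hit: reuse the 2D RoPE table already computed for this (height, width)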
return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_encoder_blocks): + s = self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_encoder_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/improved_dit.py b/src/models/denoiser/improved_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..99e2f5ad46db998cacb44427e86da9820d7fab35 --- /dev/null +++ b/src/models/denoiser/improved_dit.py @@ -0,0 +1,301 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, 
max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * 
freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class DiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, 
width, device, dtype): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device, dtype) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, masks=None): + if masks is None: + masks = [None, ]*self.num_blocks + if isinstance(masks, torch.Tensor): + masks = masks.unbind(0) + if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: + masks = masks + [None]*(self.num_blocks-len(masks)) + + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + B, L, C = x.shape + t = self.t_embedder(t.view(-1)).view(B, -1, C) + y = self.y_embedder(y).view(B, 1, C) + condition = nn.functional.silu(t + y) + for i, block in enumerate(self.blocks): + x = block(x, condition, pos, masks[i]) + x = self.final_layer(x, condition) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x \ No newline at end of file diff --git a/src/models/encoder.py b/src/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7f96ab394384ce8966c1a6f487705c9ec91378 --- /dev/null +++ b/src/models/encoder.py @@ -0,0 +1,132 @@ +import torch +import copy +import os +import timm +import transformers +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from torchvision.transforms import Normalize + +class RandViT(nn.Module): + def __init__(self, model_id, weight_path:str=None): + super(RandViT, self).__init__() + self.encoder = timm.create_model( + model_id, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 
1) + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class DINO(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINO, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([ 0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([ 1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class CLIP(nn.Module): + def __init__(self, model_id, weight_path:str): + super(CLIP, self).__init__() + self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path) + self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size + self.shifts = nn.Parameter(torch.tensor([0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + + +class DINOv2(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = transformers.Dinov2Model.from_pretrained(weight_path) + self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size + + def forward(self, x): + x = 
Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature \ No newline at end of file diff --git a/src/models/vae.py b/src/models/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c47b08764384fb9bffc6f918b9a099a190248b85 --- /dev/null +++ b/src/models/vae.py @@ -0,0 +1,81 @@ +import torch +import subprocess +import lightning.pytorch as pl + +import logging + + +logger = logging.getLogger(__name__) +def class_fn_from_str(class_str): + class_module, from_class = class_str.rsplit(".", 1) + class_module = __import__(class_module, fromlist=[from_class]) + return getattr(class_module, from_class) + + +class BaseVAE(torch.nn.Module): + def __init__(self, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + + def encode(self, x): + return x/self.scale+self.shift + + def decode(self, x): + return (x-self.shift)*self.scale + + +# very bad bugs with nearest sampling +class DownSampleVAE(BaseVAE): + def __init__(self, down_ratio, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + self.down_ratio = down_ratio + + def encode(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False) + return x/self.scale+self.shift + + def decode(self, x): + x = (x-self.shift)*self.scale + x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False) + return x + + + +class LatentVAE(BaseVAE): + def __init__(self, precompute=False, weight_path:str=None): + super().__init__() + self.precompute = precompute + self.model = None + self.weight_path = weight_path + + from diffusers.models import AutoencoderKL + setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path)) + self.scaling_factor = self.model.config.scaling_factor + + @torch.no_grad() + def encode(self, x): + assert self.model is not None + if self.precompute: + return x.mul_(self.scaling_factor) + return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor) + + @torch.no_grad() + def decode(self, x): + assert self.model is not None + return self.model.decode(x.div_(self.scaling_factor)).sample + + +def uint82fp(x): + x = x.to(torch.float32) + x = (x - 127.5) / 127.5 + return x + +def fp2uint8(x): + x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8) + return x + diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/plugins/bd_env.py b/src/plugins/bd_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c1900e9c34422e58cf14dedf285a4e162cee62db --- /dev/null +++ b/src/plugins/bd_env.py @@ -0,0 +1,70 @@ +import torch +import os +import socket +from typing_extensions import override +from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.plugins.environments.lightning import LightningEnvironment + + +class BDEnvironment(LightningEnvironment): + pass + # def __init__(self) -> None: + # super().__init__() + # self._global_rank: int 
= 0 + # self._world_size: int = 1 + # + # @property + # @override + # def creates_processes_externally(self) -> bool: + # """Returns whether the cluster creates the processes or not. + # + # If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the + # process launcher/job scheduler and Lightning will not launch new processes. + # + # """ + # return "LOCAL_RANK" in os.environ + # + # @staticmethod + # @override + # def detect() -> bool: + # assert "ARNOLD_WORKER_0_HOST" in os.environ.keys() + # assert "ARNOLD_WORKER_0_PORT" in os.environ.keys() + # return True + # + # @override + # def world_size(self) -> int: + # return self._world_size + # + # @override + # def set_world_size(self, size: int) -> None: + # self._world_size = size + # + # @override + # def global_rank(self) -> int: + # return self._global_rank + # + # @override + # def set_global_rank(self, rank: int) -> None: + # self._global_rank = rank + # rank_zero_only.rank = rank + # + # @override + # def local_rank(self) -> int: + # return int(os.environ.get("LOCAL_RANK", 0)) + # + # @override + # def node_rank(self) -> int: + # return int(os.environ.get("ARNOLD_ID")) + # + # @override + # def teardown(self) -> None: + # if "WORLD_SIZE" in os.environ: + # del os.environ["WORLD_SIZE"] + # + # @property + # def main_address(self) -> str: + # return os.environ.get("ARNOLD_WORKER_0_HOST") + # + # @property + # def main_port(self) -> int: + # return int(os.environ.get("ARNOLD_WORKER_0_PORT")) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/utils/copy.py b/src/utils/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..62cd89da7ffd0f3b65fd0206b9c646f8df5e64c4 --- /dev/null +++ b/src/utils/copy.py @@ -0,0 +1,13 @@ +import torch + +@torch.no_grad() +def copy_params(src_model, dst_model): + for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()): + dst_param.data.copy_(src_param.data) + +@torch.no_grad() +def swap_tensors(tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) \ No newline at end of file diff --git a/src/utils/model_loader.py b/src/utils/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7d991665c88e8bd15eae9d8daecb580e9a6312fb --- /dev/null +++ b/src/utils/model_loader.py @@ -0,0 +1,29 @@ +from typing import Dict, Any, Optional + +import torch +import torch.nn as nn +from lightning.fabric.utilities.types import _PATH + + +import logging +logger = logging.getLogger(__name__) + +class ModelLoader: + def __init__(self,): + super().__init__() + + def load(self, denoiser, prefix=""): + if denoiser.weight_path: + weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) + + if denoiser.load_ema: + prefix = "ema_denoiser." + prefix + else: + prefix = "denoiser." 
+ prefix + + for k, v in denoiser.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + logger.warning(f"Failed to copy {prefix+k} to denoiser weight") + return denoiser \ No newline at end of file diff --git a/src/utils/no_grad.py b/src/utils/no_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd71dedf18050634fa69574a1e6f9acf7d7131e --- /dev/null +++ b/src/utils/no_grad.py @@ -0,0 +1,16 @@ +import torch + +@torch.no_grad() +def no_grad(net): + for param in net.parameters(): + param.requires_grad = False + net.eval() + return net + +@torch.no_grad() +def filter_nograd_tensors(params_list): + filtered_params_list = [] + for param in params_list: + if param.requires_grad: + filtered_params_list.append(param) + return filtered_params_list \ No newline at end of file diff --git a/src/utils/patch_bugs.py b/src/utils/patch_bugs.py new file mode 100644 index 0000000000000000000000000000000000000000..db9a174793eac1b20e25adc8a2103dba961356ed --- /dev/null +++ b/src/utils/patch_bugs.py @@ -0,0 +1,17 @@ +import torch +import lightning.pytorch.loggers.wandb as wandb + +setattr(wandb, '_WANDB_AVAILABLE', True) +torch.set_float32_matmul_precision('medium') + +import logging +logger = logging.getLogger("wandb") +logger.setLevel(logging.WARNING) + +import os +os.environ["NCCL_DEBUG"] = "WARN" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=UserWarning)
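For reference, a minimal usage sketch (not part of the patch) of how the 2D rotary-embedding helpers added above are expected to be called. The sizes are assumptions for illustration only: head_dim = 72 (hidden size 1152 split over 16 attention groups) and a 32x32 latent token grid; the import path mirrors the new improved_dit.py file.

import torch
from src.models.denoiser.improved_dit import precompute_freqs_cis_2d, apply_rotary_emb

B, num_heads, head_dim = 2, 16, 72                        # assumed example sizes
height = width = 32                                       # assumed latent tokens per side
pos = precompute_freqs_cis_2d(head_dim, height, width)    # (height*width, head_dim/2), complex64
q = torch.randn(B, height * width, num_heads, head_dim)   # B, N, H, Hc layout expected by apply_rotary_emb
k = torch.randn(B, height * width, num_heads, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=pos)      # same shapes back, rotated by 2D position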