Spaces: Running on Zero
wangshuai6 committed · 9e426da
Parent(s): df79fa7
init space
This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +170 -0
- configs/repa_improved_ddt_xlen22de6_256.yaml +108 -0
- configs/repa_improved_ddt_xlen22de6_512.yaml +108 -0
- configs/repa_improved_dit_large.yaml +99 -0
- configs/repa_improved_dit_xl.yaml +99 -0
- requirements.txt +6 -0
- src/__init__.py +0 -0
- src/callbacks/__init__.py +0 -0
- src/callbacks/grad.py +22 -0
- src/callbacks/model_checkpoint.py +21 -0
- src/callbacks/save_images.py +105 -0
- src/callbacks/simple_ema.py +79 -0
- src/data/__init__.py +1 -0
- src/data/dataset/__init__.py +0 -0
- src/data/dataset/celeba.py +11 -0
- src/data/dataset/imagenet.py +82 -0
- src/data/dataset/metric_dataset.py +82 -0
- src/data/dataset/randn.py +41 -0
- src/data/var_training.py +145 -0
- src/diffusion/__init__.py +0 -0
- src/diffusion/base/guidance.py +60 -0
- src/diffusion/base/sampling.py +31 -0
- src/diffusion/base/scheduling.py +32 -0
- src/diffusion/base/training.py +29 -0
- src/diffusion/ddpm/ddim_sampling.py +40 -0
- src/diffusion/ddpm/scheduling.py +102 -0
- src/diffusion/ddpm/training.py +83 -0
- src/diffusion/ddpm/vp_sampling.py +59 -0
- src/diffusion/flow_matching/adam_sampling.py +107 -0
- src/diffusion/flow_matching/sampling.py +179 -0
- src/diffusion/flow_matching/scheduling.py +39 -0
- src/diffusion/flow_matching/training.py +55 -0
- src/diffusion/flow_matching/training_cos.py +59 -0
- src/diffusion/flow_matching/training_repa.py +137 -0
- src/diffusion/pre_integral.py +143 -0
- src/diffusion/stateful_flow_matching/adam_sampling.py +112 -0
- src/diffusion/stateful_flow_matching/sampling.py +103 -0
- src/diffusion/stateful_flow_matching/scheduling.py +39 -0
- src/diffusion/stateful_flow_matching/sharing_sampling.py +149 -0
- src/diffusion/stateful_flow_matching/training.py +55 -0
- src/diffusion/stateful_flow_matching/training_repa.py +152 -0
- src/lightning_data.py +162 -0
- src/lightning_model.py +123 -0
- src/models/__init__.py +0 -0
- src/models/conditioner.py +26 -0
- src/models/denoiser/__init__.py +0 -0
- src/models/denoiser/decoupled_improved_dit.py +308 -0
- src/models/denoiser/improved_dit.py +301 -0
- src/models/encoder.py +132 -0
- src/models/vae.py +81 -0
app.py
ADDED
@@ -0,0 +1,170 @@
# vae:
#   class_path: src.models.vae.LatentVAE
#   init_args:
#     precompute: true
#     weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
# denoiser:
#   class_path: src.models.denoiser.decoupled_improved_dit.DDT
#   init_args:
#     in_channels: 4
#     patch_size: 2
#     num_groups: 16
#     hidden_size: &hidden_dim 1152
#     num_blocks: 28
#     num_encoder_blocks: 22
#     num_classes: 1000
# conditioner:
#   class_path: src.models.conditioner.LabelConditioner
#   init_args:
#     null_class: 1000
# diffusion_sampler:
#   class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
#   init_args:
#     num_steps: 250
#     guidance: 3.0
#     state_refresh_rate: 1
#     guidance_interval_min: 0.3
#     guidance_interval_max: 1.0
#     timeshift: 1.0
#     last_step: 0.04
#     scheduler: *scheduler
#     w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
#     guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
#     step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn

import torch
import argparse
from omegaconf import OmegaConf
from src.models.vae import fp2uint8
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download


def instantiate_class(config):
    # Resolve a {class_path, init_args} config entry into a live object.
    kwargs = config.get("init_args", {})
    class_module, class_name = config["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(**kwargs)

def load_model(weight_dict, denoiser):
    # Copy EMA weights from the Lightning checkpoint into the denoiser.
    prefix = "ema_denoiser."
    for k, v in denoiser.state_dict().items():
        try:
            v.copy_(weight_dict["state_dict"][prefix + k])
        except Exception:
            print(f"Failed to copy {prefix + k} to denoiser weight")
    return denoiser


class Pipeline:
    def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
        self.vae = vae
        self.denoiser = denoiser
        self.conditioner = conditioner
        self.diffusion_sampler = diffusion_sampler
        self.resolution = resolution

    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
        self.diffusion_sampler.num_steps = num_steps
        self.diffusion_sampler.guidance = guidance
        self.diffusion_sampler.state_refresh_rate = state_refresh_rate
        self.diffusion_sampler.guidance_interval_min = guidance_interval_min
        self.diffusion_sampler.guidance_interval_max = guidance_interval_max
        self.diffusion_sampler.timeshift = timeshift
        generator = torch.Generator(device="cuda").manual_seed(seed)
        xT = torch.randn((num_images, 4, self.resolution // 8, self.resolution // 8), device="cuda", dtype=torch.float32, generator=generator)
        with torch.no_grad():
            condition, uncondition = self.conditioner([y, ] * num_images)
            # Sample images:
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
            samples = self.vae.decode(samples)
        # fp32 [-1, 1] -> uint8 [0, 255]
        samples = fp2uint8(samples)
        samples = samples.permute(0, 2, 3, 1).cpu().numpy()
        images = []
        for i in range(num_images):
            image = Image.fromarray(samples[i])
            images.append(image)
        return images

import os
import spaces
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R512")
    parser.add_argument("--ckpt_path", type=str, default="models")
    args = parser.parse_args()

    if not os.path.exists(args.ckpt_path):
        snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)

    config = OmegaConf.load(args.config)
    vae_config = config.model.vae
    diffusion_sampler_config = config.model.diffusion_sampler
    denoiser_config = config.model.denoiser
    conditioner_config = config.model.conditioner

    vae = instantiate_class(vae_config)
    denoiser = instantiate_class(denoiser_config)
    conditioner = instantiate_class(conditioner_config)


    diffusion_sampler = EulerSampler(
        scheduler=LinearScheduler(),
        w_scheduler=LinearScheduler(),
        guidance_fn=simple_guidance_fn,
        num_steps=50,
        guidance=3.0,
        state_refresh_rate=1,
        guidance_interval_min=0.3,
        guidance_interval_max=1.0,
        timeshift=1.0
    )
    ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
    ckpt = torch.load(ckpt_path, map_location="cpu")
    denoiser = load_model(ckpt, denoiser)
    denoiser = denoiser.cuda()
    vae = vae.cuda()
    denoiser.eval()

    pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)

    with gr.Blocks() as demo:
        gr.Markdown("DDT")
        with gr.Row():
            with gr.Column(scale=1):
                num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
                guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
                num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
                label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
                seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
                state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
                guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
                guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
                timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
            with gr.Column(scale=2):
                btn = gr.Button("Generate")
                output = gr.Gallery(label="Images")

        btn.click(fn=pipeline,
                  inputs=[
                      label,
                      num_images,
                      seed,
                      num_steps,
                      guidance,
                      state_refresh_rate,
                      guidance_interval_min,
                      guidance_interval_max,
                      timeshift
                  ], outputs=[output])
    demo.launch(server_name="0.0.0.0", server_port=7861)
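For reference, `instantiate_class` above turns any `class_path`/`init_args` pair from the OmegaConf config into a live object. A minimal illustration of that behavior (the dict mirrors the `conditioner` block of the bundled YAML configs; the call itself is not part of app.py):

# Illustrative only: how instantiate_class resolves one config entry.
conditioner_cfg = {
    "class_path": "src.models.conditioner.LabelConditioner",
    "init_args": {"null_class": 1000},
}
conditioner = instantiate_class(conditioner_cfg)
# equivalent to: from src.models.conditioner import LabelConditioner; LabelConditioner(null_class=1000)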
configs/repa_improved_ddt_xlen22de6_256.yaml
ADDED
@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_flatten_condit22_dit6_fixt_xl
  torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
  huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 4000000
  val_check_interval: 4000000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: stabilityai/sd-vae-ft-ema
  denoiser:
    class_path: src.models.denoiser.decoupled_improved_dit.DDT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_encoder_blocks: 22
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 2.0
      timeshift: 1.0
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      last_step: 0.04
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 16
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
configs/repa_improved_ddt_xlen22de6_512.yaml
ADDED
@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
  torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
  huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 4000000
  val_check_interval: 4000000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: stabilityai/sd-vae-ft-ema
  denoiser:
    class_path: src.models.denoiser.decoupled_improved_dit.DDT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_encoder_blocks: 22
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 3.0
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      timeshift: 1.0
      last_step: 0.04
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  train_dataset: imagenet512
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 512
  train_batch_size: 16
  eval_max_num_instances: 50000
  pred_batch_size: 32
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 64
    - 64
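As a quick sanity check on the numbers above (illustrative, not part of the commit): the latent grid is the training image size divided by the SD-VAE's 8x spatial downsampling, which is also how app.py shapes its initial noise.

# Illustrative: latent_shape follows resolution // 8, matching app.py's
# xT = torch.randn((num_images, 4, resolution // 8, resolution // 8), ...)
resolution = 512
latent_shape = (4, resolution // 8, resolution // 8)  # -> (4, 64, 64), as listed in this config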
configs/repa_improved_dit_large.yaml
ADDED
@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_improved_dit_large
  torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
  huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 400000
  val_check_interval: 100000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.improved_dit.DiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1024
      num_blocks: 24
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.00
      scheduler: *scheduler
      w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 32
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
configs/repa_improved_dit_xl.yaml
ADDED
@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_improved_dit_xlen22de6_512
  torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
  huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 400000
  val_check_interval: 100000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.improved_dit.DiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.00
      scheduler: *scheduler
      w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 32
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
requirements.txt
ADDED
@@ -0,0 +1,6 @@
lightning==2.5.0.post0
omegaconf==2.3.0
torch==2.3.0
diffusers==0.30.0
jsonargparse[signatures]>=4.27.7
accelerate
src/__init__.py
ADDED
File without changes
src/callbacks/__init__.py
ADDED
File without changes
src/callbacks/grad.py
ADDED
@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer

class GradientMonitor(pl.Callback):
    """Logs the gradient norm"""

    def __init__(self, norm_type: int = 2):
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer
    ) -> None:
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})
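This callback is not referenced by the shipped configs; if one wanted to use it, it would attach like any other Lightning callback. A hedged usage sketch (Trainer arguments illustrative):

# Illustrative only: attaching GradientMonitor when constructing a Trainer manually.
import lightning.pytorch as pl

trainer = pl.Trainer(callbacks=[GradientMonitor(norm_type=2)])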
src/callbacks/model_checkpoint.py
ADDED
@@ -0,0 +1,21 @@
import os.path
from typing import Optional, Dict, Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model"""
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        self.dirpath = trainer.default_root_dir
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        # Drop callback state so the checkpoint only carries module weights.
        del checkpoint["callbacks"]
src/callbacks/save_images.py
ADDED
@@ -0,0 +1,105 @@
import lightning.pytorch as pl
from lightning.pytorch import Callback


import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor

from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info

def process_fn(image, path):
    Image.fromarray(image).save(path)

class SaveImagesHook(Callback):
    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed

    def save_start(self, target_dir):
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        else:
            if os.listdir(target_dir) and "debug" not in str(target_dir):
                raise FileExistsError(f'{self.target_dir} already exists and not empty!')
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if isinstance(filename, Sequence):
                filename = filename[0]
            path = f'{self.target_dir}/{filename}'
            if self._have_saved_num >= self.max_save_num:
                break
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        b, c, h, w = samples.shape
        xT, y, metadata = batch
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
src/callbacks/simple_ema.py
ADDED
@@ -0,0 +1,79 @@
from typing import Any, Dict

import torch
import torch.nn as nn
import threading
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

from src.utils.copy import swap_tensors

class SimpleEMA(Callback):
    def __init__(self, net: nn.Module, ema_net: nn.Module,
                 decay: float = 0.9999,
                 every_n_steps: int = 1,
                 eval_original_model: bool = False
                 ):
        super().__init__()
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.eval_original_model = eval_original_model
        self._stream = torch.cuda.Stream()

        self.net_params = list(net.parameters())
        self.ema_params = list(ema_net.parameters())

    def swap_model(self):
        for ema_p, p in zip(self.ema_params, self.net_params):
            swap_tensors(ema_p, p)

    def ema_step(self):
        @torch.no_grad()
        def ema_update(ema_model_tuple, current_model_tuple, decay):
            torch._foreach_mul_(ema_model_tuple, decay)
            torch._foreach_add_(
                ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
            )

        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._stream):
            ema_update(self.ema_params, self.net_params, self.decay)


    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if trainer.global_step % self.every_n_steps == 0:
            self.ema_step()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()


    def state_dict(self) -> Dict[str, Any]:
        return {
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
            "eval_original_model": self.eval_original_model,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.eval_original_model = state_dict["eval_original_model"]
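The `_foreach` calls above are a fused, multi-tensor form of the usual exponential moving average. A small illustrative equivalent on a single parameter tensor (plain tensors, no CUDA stream, same decay convention):

# Illustrative only: per-parameter form of SimpleEMA.ema_step.
# ema <- decay * ema + (1 - decay) * param
import torch

decay = 0.9999
ema_p = torch.zeros(3)
p = torch.ones(3)
ema_p.mul_(decay).add_(p, alpha=1.0 - decay)  # ema_p is now 0.0001 everywhere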
src/data/__init__.py
ADDED
@@ -0,0 +1 @@
src/data/dataset/__init__.py
ADDED
File without changes
src/data/dataset/celeba.py
ADDED
@@ -0,0 +1,11 @@
from typing import Callable
from torchvision.datasets import CelebA


class LocalDataset(CelebA):
    def __init__(self, root: str, ):
        super(LocalDataset, self).__init__(root, "train")

    def __getitem__(self, idx):
        data = super().__getitem__(idx)
        return data
src/data/dataset/imagenet.py
ADDED
@@ -0,0 +1,82 @@
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Normalize

from src.data.dataset.metric_dataset import CenterCrop

class LocalCachedDataset(ImageFolder):
    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.cache_root = None

    def load_latent(self, latent_path):
        # Sample a latent from the precomputed VAE posterior (mean, logvar).
        pk_data = torch.load(latent_path)
        mean = pk_data['mean'].to(torch.float32)
        logvar = pk_data['logvar'].to(torch.float32)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        latent = mean + torch.randn_like(mean) * std
        return latent

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        latent_path = image_path.replace(self.root, self.cache_root) + ".pt"

        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)
        if self.cache_root is not None:
            latent = self.load_latent(latent_path)
        else:
            latent = raw_image
        return raw_image, latent, target

class ImageNet256(LocalCachedDataset):
    def __init__(self, root, ):
        super().__init__(root, 256)
        self.cache_root = root + "_256_latent"

class ImageNet512(LocalCachedDataset):
    def __init__(self, root, ):
        super().__init__(root, 512)
        self.cache_root = root + "_512_latent"

class PixImageNet(ImageFolder):
    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)

        normalized_image = self.normalize(raw_image)
        return raw_image, normalized_image, target

class PixImageNet64(PixImageNet):
    def __init__(self, root, ):
        super().__init__(root, 64)

class PixImageNet128(PixImageNet):
    def __init__(self, root, ):
        super().__init__(root, 128)


class PixImageNet256(PixImageNet):
    def __init__(self, root, ):
        super().__init__(root, 256)

class PixImageNet512(PixImageNet):
    def __init__(self, root, ):
        super().__init__(root, 512)
ADDED
@@ -0,0 +1,82 @@
|
import pathlib

import torch
import random
import numpy as np
from PIL import Image
from torchvision.io.image import read_image
import torchvision.transforms as tvtf
from torch.utils.data import Dataset

class CenterCrop:
    def __init__(self, size):
        self.size = size
    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)

def test_collate(batch):
    return torch.stack(batch)

class ImageDataset(Dataset):
    def __init__(self, root, image_size=(224, 224)):
        self.root = pathlib.Path(root)
        images = []
        for ext in IMG_EXTENSIONS:
            images.extend(self.root.rglob(ext))
        random.shuffle(images)
        self.images = list(map(lambda x: str(x), images))
        self.transform = tvtf.Compose(
            [
                CenterCrop(image_size[0]),
                tvtf.ToTensor(),
                tvtf.Lambda(lambda x: (x * 255).to(torch.uint8)),
                tvtf.Lambda(lambda x: x.expand(3, -1, -1))
            ]
        )
        self.size = image_size

    def __getitem__(self, idx):
        try:
            image = Image.open(self.images[idx])
            image = self.transform(image)
        except Exception as e:
            print(self.images[idx])
            image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)

        metadata = dict(
            path=self.images[idx],
            root=self.root,
        )
        return image  # , metadata

    def __len__(self):
        return len(self.images)
src/data/dataset/randn.py
ADDED
@@ -0,0 +1,41 @@
import os.path
import random

import torch
from torch.utils.data import Dataset



class RandomNDataset(Dataset):
    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes: list = None, seeds=None, max_num_instances=50000, ):
        self.selected_classes = selected_classes
        if selected_classes is not None:
            num_classes = len(selected_classes)
            max_num_instances = 10 * num_classes
        self.num_classes = num_classes
        self.seeds = seeds
        if seeds is not None:
            self.max_num_instances = len(seeds) * num_classes
            self.num_seeds = len(seeds)
        else:
            self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
            self.max_num_instances = self.num_seeds * num_classes

        self.latent_shape = latent_shape


    def __getitem__(self, idx):
        label = idx // self.num_seeds
        if self.selected_classes:
            label = self.selected_classes[label]
        seed = random.randint(0, 1 << 31)  # idx % self.num_seeds
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]

        # cls_dir = os.path.join(self.root, f"{label}")
        filename = f"{label}_{seed}.png",  # note: trailing comma makes this a 1-tuple; SaveImagesHook unwraps it
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, label, filename
    def __len__(self):
        return self.max_num_instances
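For orientation: each item is a seeded Gaussian latent plus a class label and a filename tuple, and passing an explicit `seeds` list makes the evaluation noise reproducible across runs. A small usage sketch (values illustrative):

# Illustrative usage: fixed seeds give deterministic evaluation latents.
ds = RandomNDataset(latent_shape=(4, 32, 32), num_classes=1000, seeds=[0, 1, 2, 3])
latent, label, filename = ds[0]  # latent: (4, 32, 32) float32, filename like ("0_0.png",)
len(ds)                          # 4 seeds * 1000 classes = 4000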
src/data/var_training.py
ADDED
@@ -0,0 +1,145 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor
from typing import List
from PIL import Image
import random
import numpy as np
import copy
import torchvision.transforms.functional as tvtf
from src.models.vae import uint82fp


def center_crop_arr(pil_image, width, height):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, (arr.shape[0] - height))
    crop_x = random.randint(0, (arr.shape[1] - width))
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])

def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # hflip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # crop
    image = np.array(image).transpose(2, 0, 1)
    return image, label

class VARCandidate:
    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        self.buffer.append(data)
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))
        return x, y

class VARTransformEngine:
    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024
            )
            self.candidates_pool.append(candidate)
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        self._prefill_count = 100

    def find_candidate(self, data):
        # Pick the aspect-ratio bucket closest to the incoming image.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = 1000000
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate


    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)

        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)

        batch_size = len(batch_data)
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)

        # fallback to default 256
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
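To make the bucket geometry concrete: each aspect-ratio candidate scales the base size by the square root of the ratio and snaps both sides down to a multiple of 16, so the pixel area stays roughly constant across buckets. A worked example using the same rounding as `VARTransformEngine.__init__` (numbers illustrative):

# Illustrative: bucket resolution for aspect ratio 1.5 at base_image_size = 256.
base = 256
ar = 1.5
width = int(base * ar ** 0.5 // 16 * 16)    # 313.5 -> 304
height = int(base * ar ** -0.5 // 16 * 16)  # 209.0 -> 208
# -> a 304x208 bucket; width/height is about 1.46, close to the requested 1.5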
src/diffusion/__init__.py
ADDED
File without changes
src/diffusion/base/guidance.py
ADDED
@@ -0,0 +1,60 @@
import torch

def simple_guidance_fn(out, cfg):
    uncondition, condition = out.chunk(2, dim=0)
    out = uncondition + cfg * (condition - uncondition)
    return out

def c3_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
    return out

def c4_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p05_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p10_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p15_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p20_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def p4_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
    return out
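`simple_guidance_fn` is the standard classifier-free guidance combination: the network is run on a batch that stacks the unconditional and conditional halves, and the two halves are blended. A toy illustration of the shapes involved (not part of the repository):

# Illustrative: classifier-free guidance on a doubled batch, as the samplers use it.
import torch

b, c, h, w = 2, 4, 32, 32
out = torch.randn(2 * b, c, h, w)          # [uncond | cond] halves, as produced by net(cfg_x, ...)
guided = simple_guidance_fn(out, cfg=3.0)  # uncond + 3.0 * (cond - uncond), shape (b, c, h, w)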
src/diffusion/base/sampling.py
ADDED
@@ -0,0 +1,31 @@
from typing import Union, List

import torch
import torch.nn as nn
from typing import Callable
from src.diffusion.base.scheduling import BaseScheduler

class BaseSampler(nn.Module):
    def __init__(self,
                 scheduler: BaseScheduler = None,
                 guidance_fn: Callable = None,
                 num_steps: int = 250,
                 guidance: Union[float, List[float]] = 1.0,
                 *args,
                 **kwargs
                 ):
        super(BaseSampler, self).__init__()
        self.num_steps = num_steps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.scheduler = scheduler


    def _impl_sampling(self, net, noise, condition, uncondition):
        raise NotImplementedError

    def __call__(self, net, noise, condition, uncondition):
        denoised = self._impl_sampling(net, noise, condition, uncondition)
        return denoised
src/diffusion/base/scheduling.py
ADDED
@@ -0,0 +1,32 @@
import torch
from torch import Tensor

class BaseScheduler:
    def alpha(self, t) -> Tensor:
        ...
    def sigma(self, t) -> Tensor:
        ...

    def dalpha(self, t) -> Tensor:
        ...
    def dsigma(self, t) -> Tensor:
        ...

    def dalpha_over_alpha(self, t) -> Tensor:
        return self.dalpha(t) / self.alpha(t)

    def dsigma_mul_sigma(self, t) -> Tensor:
        return self.dsigma(t)*self.sigma(t)

    def drift_coefficient(self, t):
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dalpha/(alpha + 1e-6)

    def diffuse_coefficient(self, t):
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2

    def w(self, t):
        return self.sigma(t)
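As a quick illustration of the interface (not part of the commit): a concrete scheduler only has to supply alpha/sigma and their time derivatives; the derived quantities above follow automatically. This toy subclass mirrors the LinearScheduler added later in this commit:

    class ToyLinearScheduler(BaseScheduler):
        # forward process x_t = alpha(t) * x0 + sigma(t) * noise, with alpha(t) = t and sigma(t) = 1 - t
        def alpha(self, t): return t.view(-1, 1, 1, 1)
        def sigma(self, t): return (1 - t).view(-1, 1, 1, 1)
        def dalpha(self, t): return torch.ones_like(t).view(-1, 1, 1, 1)
        def dsigma(self, t): return -torch.ones_like(t).view(-1, 1, 1, 1)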
src/diffusion/base/training.py
ADDED
@@ -0,0 +1,29 @@
import time

import torch
import torch.nn as nn

class BaseTrainer(nn.Module):
    def __init__(self,
                 null_condition_p=0.1,
                 log_var=False,
                 ):
        super(BaseTrainer, self).__init__()
        self.null_condition_p = null_condition_p
        self.log_var = log_var

    def preproprocess(self, raw_iamges, x, condition, uncondition):
        bsz = x.shape[0]
        if self.null_condition_p > 0:
            mask = torch.rand((bsz), device=condition.device) < self.null_condition_p
            mask = mask.expand_as(condition)
            condition[mask] = uncondition[mask]
        return raw_iamges, x, condition

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        raise NotImplementedError

    def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
        raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition)
        return self._impl_trainstep(net, ema_net, raw_images, x, condition)
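For orientation, a sketch of how a concrete trainer built on this base class is driven during one step (net, ema_net and the batch tensors are placeholders; FlowMatchingTrainer and LinearScheduler appear later in this commit):

    trainer = FlowMatchingTrainer(scheduler=LinearScheduler(), null_condition_p=0.1)
    out = trainer(net, ema_net, raw_images, latents, labels, null_labels)  # __call__ drops conditions at rate null_condition_p
    out["loss"].backward()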
src/diffusion/ddpm/ddim_sampling.py
ADDED
@@ -0,0 +1,40 @@
import torch
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable

import logging
logger = logging.getLogger(__name__)

class DDIMSampler(BaseSampler):
    def __init__(
            self,
            train_num_steps=1000,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.train_num_steps = train_num_steps
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device)
        steps = torch.flip(steps, dims=[0])
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            t_cur = t_cur.repeat(batch_size)
            t_next = t_next.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha = self.scheduler.alpha(t_cur)
            sigma_next = self.scheduler.sigma(t_next)
            alpha_next = self.scheduler.alpha(t_next)
            cfg_x = torch.cat([x, x], dim=0)
            t = t_cur.repeat(2)
            out = net(cfg_x, t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            x0 = (x - sigma * out) / alpha
            x = alpha_next * x0 + sigma_next * out
        return x0
src/diffusion/ddpm/scheduling.py
ADDED
@@ -0,0 +1,102 @@
import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.0001,
            beta_max=0.02,
            num_steps=1000,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps

        self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
        self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
        self.sigmas_table = 1-self.alphas_table

    def beta(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.betas_table[t].view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.alphas_table[t].view(-1, 1, 1, 1)**0.5

    def sigma(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")


class VPScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.1,
            beta_max=20,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min

    def beta(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5*self.beta_d*t**2 + self.beta_min*t
        return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def alpha(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        return self.diffuse_coefficient(t)
src/diffusion/ddpm/training.py
ADDED
@@ -0,0 +1,83 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class VPTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            train_max_t=1000,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t*self.train_max_t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight*(out - noise)**2

        out = dict(
            loss=loss.mean(),
        )
        return out


class DDPMTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            train_max_t=1000,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        t = torch.randint(0, self.train_max_t, (batch_size,))
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
src/diffusion/ddpm/vp_sampling.py
ADDED
@@ -0,0 +1,59 @@
import torch

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable

def ode_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt

def sde_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x)

import logging
logger = logging.getLogger(__name__)

class VPEulerSampler(BaseSampler):
    def __init__(
            self,
            train_max_t=1000,
            guidance_fn: Callable = None,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.train_max_t = train_max_t

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        assert self.last_step > 0.0
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
        steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            beta = self.scheduler.beta(t_cur)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition)
            eps = self.guidance_fn(out, self.guidance)
            if i < self.num_steps - 1:
                x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
                x = self.step_fn(x, eps, beta, sigma, dt)
            else:
                x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
        return x
src/diffusion/flow_matching/adam_sampling.py
ADDED
@@ -0,0 +1,107 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def t2snr(t):
    if isinstance(t, torch.Tensor):
        return (t.clip(min=1e-8)/(1-t + 1e-8))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return (t/(1-t + 1e-8))

def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t/(1-t + 1e-3))

def t2isnr(t):
    return 1/t2snr(t)

def nop(t):
    return t

def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

import logging
logger = logging.getLogger(__name__)

class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        self.timedeltas = timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ]*(i+1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i+1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Euler sampler
        -
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        pred_trajectory = []
        t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidances[i])
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
src/diffusion/flow_matching/sampling.py
ADDED
@@ -0,0 +1,179 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v*dt + s * w * dt + torch.sqrt(2*w*dt)*torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v*dt + 0.5*s*w * dt + torch.sqrt(w*dt)*torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Euler sampler
        -
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            v = out
            s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x


class HeunSampler(BaseSampler):
    def __init__(
            self,
            scheduler: BaseScheduler = None,
            w_scheduler: BaseScheduler = None,
            exact_henu=False,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Heun sampler
        -
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        v_hat, s_hat = 0.0, 0.0
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            t_hat = t_next
            t_hat = t_hat.repeat(batch_size)
            sigma_hat = self.scheduler.sigma(t_hat)
            alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
            dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)

            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0
            if i == 0 or self.exact_henu:
                cfg_x = torch.cat([x, x], dim=0)
                cfg_t_cur = t_cur.repeat(2)
                out = net(cfg_x, cfg_t_cur, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v = out
                s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma)
            else:
                v = v_hat
                s = s_hat
            x_hat = self.step_fn(x, v, dt, s=s, w=w)
            # Heun correction
            if i < self.num_steps - 1:
                cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
                cfg_t_hat = t_hat.repeat(2)
                out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v_hat = out
                s_hat = ((alpha_over_dalpha_hat) * v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat)
                v = (v + v_hat) / 2
                s = (s + s_hat) / 2
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
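A minimal usage sketch for the Euler sampler above (the denoiser net is a placeholder; any guidance function from base/guidance.py with a matching output layout would do):

    sampler = EulerSampler(
        scheduler=LinearScheduler(),        # from flow_matching/scheduling.py below
        guidance_fn=c4_p10_guidance_fn,     # assumes a 4+4 channel output, see base/guidance.py
        num_steps=50,
        guidance=2.0,
        timeshift=3.0,
    )
    samples = sampler(net, noise, condition, uncondition)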
src/diffusion/flow_matching/scheduling.py
ADDED
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return (t).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return (1-t).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)

# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def w(self, t):
        return torch.sin(t)**2

class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)

from src.diffusion.ddpm.scheduling import VPScheduler
class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
src/diffusion/flow_matching/training.py
ADDED
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class FlowMatchingTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        loss = weight*(out - v_t)**2

        out = dict(
            loss=loss.mean(),
        )
        return out
src/diffusion/flow_matching/training_cos.py
ADDED
@@ -0,0 +1,59 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class COSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        fm_loss = weight*(out - v_t)**2
        cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
        cos_loss = 1 - cos_sim

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + cos_loss.mean(),
        )
        return out
src/diffusion/flow_matching/training_repa.py
ADDED
@@ -0,0 +1,137 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
src/diffusion/pre_integral.py
ADDED
@@ -0,0 +1,143 @@
import torch

# lagrange interpolation
def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
    '''
    lagrange interpolation of order 1
    Args:
        t1: timestepx
        v1: value field at t1
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    int1 = (int_t_end-int_t_start)
    return int1*v1, (int1/int1, )

def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
    '''
    lagrange interpolation of order 2
    Args:
        t1: timestepx
        t2: timestepy
        v1: value field at t1
        v2: value field at t2
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2)
    int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2)
    int_sum = int1+int2
    return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum)

def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
    '''
    lagrange interpolation of order 3
    Args:
        t1: timestepx
        t2: timestepy
        t3: timestepz
        v1: value field at t1
        v2: value field at t2
        v3: value field at t3
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    int1_denom = (t1-t2)*(t1-t3)
    int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end
    int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start
    int1 = (int1_end - int1_start)/int1_denom
    int2_denom = (t2-t1)*(t2-t3)
    int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end
    int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start
    int2 = (int2_end - int2_start)/int2_denom
    int3_denom = (t3-t1)*(t3-t2)
    int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end
    int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start
    int3 = (int3_end - int3_start)/int3_denom
    int_sum = int1+int2+int3
    return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum)

def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
    '''
    lagrange interpolation of order 4
    Args:
        t1: timestepx
        t2: timestepy
        t3: timestepz
        t4: timestepw
        v1: value field at t1
        v2: value field at t2
        v3: value field at t3
        v4: value field at t4
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    int1_denom = (t1-t2)*(t1-t3)*(t1-t4)
    int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end
    int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start
    int1 = (int1_end - int1_start)/int1_denom
    int2_denom = (t2-t1)*(t2-t3)*(t2-t4)
    int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end
    int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start
    int2 = (int2_end - int2_start)/int2_denom
    int3_denom = (t3-t1)*(t3-t2)*(t3-t4)
    int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end
    int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start
    int3 = (int3_end - int3_start)/int3_denom
    int4_denom = (t4-t1)*(t4-t2)*(t4-t3)
    int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end
    int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start
    int4 = (int4_end - int4_start)/int4_denom
    int_sum = int1+int2+int3+int4
    return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum)


def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
    '''
    lagrange interpolation
    Args:
        order: order of interpolation
        pre_vs: value field at pre_ts
        pre_ts: timesteps
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    order = min(order, len(pre_vs), len(pre_ts))
    if order == 1:
        return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
    elif order == 2:
        return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 3:
        return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 4:
        return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    else:
        raise ValueError('Invalid order')


def polynomial_integral(coeffs, int_t_start, int_t_end):
    '''
    polynomial integral
    Args:
        coeffs: coefficients of the polynomial
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    orders = len(coeffs)
    int_val = 0
    for o in range(orders):
        int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1))
    return int_val
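A small worked check of the pre-integration helpers above (illustrative only): with a constant value field v(t) = 1, the Lagrange interpolant is exactly 1, so the integral over [0.6, 0.8] must come out as 0.2 and the returned coefficients sum to 1.

    pre_ts = [0.1, 0.3, 0.6]
    pre_vs = [1.0, 1.0, 1.0]
    val, coeffs = lagrange_preint(3, pre_vs, pre_ts, 0.6, 0.8)
    assert abs(val - 0.2) < 1e-6 and abs(sum(coeffs) - 1.0) < 1e-6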
src/diffusion/stateful_flow_matching/adam_sampling.py
ADDED
@@ -0,0 +1,112 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def t2snr(t):
    if isinstance(t, torch.Tensor):
        return (t.clip(min=1e-8)/(1-t + 1e-8))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return (t/(1-t + 1e-8))

def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t/(1-t + 1e-3))

def t2isnr(t):
    return 1/t2snr(t)

def nop(t):
    return t

def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

import logging
logger = logging.getLogger(__name__)

class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            state_refresh_rate: int = 1,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler
        self.state_refresh_rate = state_refresh_rate

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        self.timedeltas = timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ]*(i+1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i+1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Euler sampler
        -
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        state = None
        pred_trajectory = []
        t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            out = self.guidance_fn(out, self.guidances[i])
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
src/diffusion/stateful_flow_matching/sampling.py
ADDED
@@ -0,0 +1,103 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v*dt + s * w * dt + torch.sqrt(2*w*dt)*torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v*dt + 0.5*s*w * dt + torch.sqrt(w*dt)*torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            guidance_interval_min: float = 0.0,
            guidance_interval_max: float = 1.0,
            state_refresh_rate=1,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Euler sampler
        -
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out
            s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
src/diffusion/stateful_flow_matching/scheduling.py
ADDED
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return (t).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return (1-t).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)

# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def w(self, t):
        return torch.sin(t)**2

class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)

from src.diffusion.ddpm.scheduling import VPScheduler
class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
src/diffusion/stateful_flow_matching/sharing_sampling.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


import logging
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    def __init__(
        self,
        w_scheduler: BaseScheduler = None,
        timeshift=1.0,
        guidance_interval_min: float = 0.0,
        guidance_interval_max: float = 1.0,
        state_refresh_rate=1,
        step_fn: Callable = ode_step_fn,
        last_step=None,
        last_step_fn: Callable = ode_step_fn,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

        # init recompute
        self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
        self.recompute_timesteps = list(range(self.num_steps))

    def sharing_dp(self, net, noise, condition, uncondition):
        _, C, H, W = noise.shape
        B = 8
        template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
        template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
        template_uncondition = torch.full((B, ), 1000, device=condition.device)
        _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
        states = torch.stack(state_list)
        N, B, L, C = states.shape
        states = states.view(N, B*L, C)
        states = states.permute(1, 0, 2)
        states = torch.nn.functional.normalize(states, dim=-1)
        with torch.autocast(device_type="cuda", dtype=torch.float64):
            sim = torch.bmm(states, states.transpose(1, 2))
        sim = torch.mean(sim, dim=0).cpu()
        error_map = (1-sim).tolist()

        # init cum-error
        for i in range(1, self.num_steps):
            for j in range(0, i):
                error_map[i][j] = error_map[i-1][j] + error_map[i][j]

        # init dp and force 0 start
        C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
        P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
        for i in range(1, self.num_steps+1):
            C[1][i] = error_map[i - 1][0]
            P[1][i] = 0

        # dp state
        for step in range(2, self.num_recompute_timesteps+1):
            for i in range(step, self.num_steps+1):
                min_value = 99999
                min_index = -1
                for j in range(step-1, i):
                    value = C[step-1][j] + error_map[i-1][j]
                    if value < min_value:
                        min_value = value
                        min_index = j
                C[step][i] = min_value
                P[step][i] = min_index

        # trace back
        timesteps = [self.num_steps,]
        for i in range(self.num_recompute_timesteps, 0, -1):
            idx = timesteps[-1]
            timesteps.append(P[i][idx])
        timesteps.reverse()

        print("recompute timesteps solved by DP: ", timesteps)
        return timesteps[:-1]

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        sampling process of Euler sampler
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        pooled_state_list = []
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i in self.recompute_timesteps:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=0.0, w=0.0)
            else:
                x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
            pooled_state_list.append(state)
        return x, pooled_state_list

    def __call__(self, net, noise, condition, uncondition):
        if len(self.recompute_timesteps) != self.num_recompute_timesteps:
            self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
        denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
        return denoised
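The sharing_dp routine above picks, by dynamic programming, which of the num_steps Euler steps actually recompute the encoder state; every other step reuses the most recent state. A small self-contained sketch of the same DP on a made-up cost table (pure Python, no model or GPU needed), just to illustrate how the recompute steps are selected:

# error[i][j]: toy cumulative cost of running steps j..i with the state computed at step j.
n, k = 6, 2                                    # n total steps, k recomputations allowed
error = [[abs(i - j) * 0.1 for j in range(n)] for i in range(1, n + 1)]

C = [[0.0] * (n + 1) for _ in range(k + 1)]    # C[step][i]: best cost up to step i
P = [[-1] * (n + 1) for _ in range(k + 1)]     # P[step][i]: argmin backpointer
for i in range(1, n + 1):                      # a single recompute must happen at step 0
    C[1][i] = error[i - 1][0]
    P[1][i] = 0
for step in range(2, k + 1):
    for i in range(step, n + 1):
        best = min((C[step - 1][j] + error[i - 1][j], j) for j in range(step - 1, i))
        C[step][i], P[step][i] = best
timesteps = [n]
for step in range(k, 0, -1):                   # trace back the chosen recompute steps
    timesteps.append(P[step][timesteps[-1]])
timesteps.reverse()
print(timesteps[:-1])                          # e.g. [0, 1]: steps where the state is refreshed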
src/diffusion/stateful_flow_matching/training.py
ADDED
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class FlowMatchingTrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        lognorm_t=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out, _ = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        loss = weight*(out - v_t)**2

        out = dict(
            loss=loss.mean(),
        )
        return out
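The trainer regresses the velocity target v_t = dalpha*x + dsigma*noise at the interpolated point x_t = alpha*x + sigma*noise. A minimal standalone sketch with a linear path (alpha = 1 - t, sigma = t, hence dalpha = -1, dsigma = 1), only to make the target concrete; the schedule here is illustrative, not the one configured in this repo:

import torch

x = torch.randn(2, 4, 8, 8)            # clean latents
noise = torch.randn_like(x)
t = torch.rand(2).view(-1, 1, 1, 1)

alpha, sigma = 1 - t, t                # linear interpolation path
dalpha, dsigma = -1.0, 1.0             # time derivatives of alpha and sigma

x_t = alpha * x + sigma * noise        # point on the path fed to the denoiser
v_t = dalpha * x + dsigma * noise      # regression target: here simply noise - x
print(torch.allclose(v_t, noise - x))  # True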
src/diffusion/stateful_flow_matching/training_repa.py
ADDED
@@ -0,0 +1,152 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        feat_loss_weight: float = 0.5,
        lognorm_t=False,
        encoder_weight_path=None,
        align_layer=8,
        proj_denoiser_dim=256,
        proj_hidden_dim=256,
        proj_encoder_dim=256,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        self.proj_encoder_dim = proj_encoder_dim
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)

        if getattr(net, "blocks", None) is not None:
            handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        else:
            handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        if dst_feature.shape[1] != src_feature.shape[1]:
            dst_length = dst_feature.shape[1]
            rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
            dst_height = (dst_length)**0.5 * (height/width)**0.5
            dst_width = (dst_length)**0.5 * (width/height)**0.5
            dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
            dst_feature = dst_feature.permute(0, 3, 1, 2)
            dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
            dst_feature = dst_feature.permute(0, 2, 3, 1)
            dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
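The REPA feature loss is just one minus the cosine similarity between projected denoiser tokens (grabbed via a forward hook on an intermediate block) and frozen encoder tokens. A small standalone sketch of that alignment term on random tensors; the dimensions and the projection head below are made up for illustration:

import torch
import torch.nn as nn

B, L, D_denoiser, D_enc = 4, 256, 1152, 768
proj = nn.Sequential(nn.Linear(D_denoiser, 256), nn.SiLU(), nn.Linear(256, D_enc))

denoiser_tokens = torch.randn(B, L, D_denoiser)   # stand-in for hooked block output
encoder_tokens = torch.randn(B, L, D_enc)          # stand-in for frozen DINOv2 patch tokens

src = proj(denoiser_tokens)
cos_loss = (1 - torch.nn.functional.cosine_similarity(src, encoder_tokens, dim=-1)).mean()
print(cos_loss)  # scalar alignment loss added to the flow-matching loss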
src/lightning_data.py
ADDED
@@ -0,0 +1,162 @@
from typing import Any
import torch
import copy
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader
from src.data.dataset.randn import RandomNDataset
from src.data.var_training import VARTransformEngine

def collate_fn(batch):
    new_batch = copy.deepcopy(batch)
    new_batch = list(zip(*new_batch))
    for i in range(len(new_batch)):
        if isinstance(new_batch[i][0], torch.Tensor):
            try:
                new_batch[i] = torch.stack(new_batch[i], dim=0)
            except:
                print("Warning: could not stack tensors")
    return new_batch

class DataModule(pl.LightningDataModule):
    def __init__(self,
                 train_root,
                 test_nature_root,
                 test_gen_root,
                 train_image_size=64,
                 train_batch_size=64,
                 train_num_workers=8,
                 var_transform_engine: VARTransformEngine = None,
                 train_prefetch_factor=2,
                 train_dataset: str = None,
                 eval_batch_size=32,
                 eval_num_workers=4,
                 eval_max_num_instances=50000,
                 pred_batch_size=32,
                 pred_num_workers=4,
                 pred_seeds: str = None,
                 pred_selected_classes=None,
                 num_classes=1000,
                 latent_shape=(4, 64, 64),
                 ):
        super().__init__()
        pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None

        self.train_root = train_root
        self.train_image_size = train_image_size
        self.train_dataset = train_dataset
        # stupid data_convert override, just to make nebular happy
        self.train_batch_size = train_batch_size
        self.train_num_workers = train_num_workers
        self.train_prefetch_factor = train_prefetch_factor

        self.test_nature_root = test_nature_root
        self.test_gen_root = test_gen_root
        self.eval_max_num_instances = eval_max_num_instances
        self.pred_seeds = pred_seeds
        self.num_classes = num_classes
        self.latent_shape = latent_shape

        self.eval_batch_size = eval_batch_size
        self.pred_batch_size = pred_batch_size

        self.pred_num_workers = pred_num_workers
        self.eval_num_workers = eval_num_workers

        self.pred_selected_classes = pred_selected_classes

        self._train_dataloader = None
        self.var_transform_engine = var_transform_engine

    def setup(self, stage: str) -> None:
        if stage == "fit":
            assert self.train_dataset is not None
            if self.train_dataset == "pix_imagenet64":
                from src.data.dataset.imagenet import PixImageNet64
                self.train_dataset = PixImageNet64(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet128":
                from src.data.dataset.imagenet import PixImageNet128
                self.train_dataset = PixImageNet128(
                    root=self.train_root,
                )
            elif self.train_dataset == "imagenet256":
                from src.data.dataset.imagenet import ImageNet256
                self.train_dataset = ImageNet256(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet256":
                from src.data.dataset.imagenet import PixImageNet256
                self.train_dataset = PixImageNet256(
                    root=self.train_root,
                )
            elif self.train_dataset == "imagenet512":
                from src.data.dataset.imagenet import ImageNet512
                self.train_dataset = ImageNet512(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet512":
                from src.data.dataset.imagenet import PixImageNet512
                self.train_dataset = PixImageNet512(
                    root=self.train_root,
                )
            else:
                raise NotImplementedError("no such dataset")

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        if self.var_transform_engine and self.trainer.training:
            batch = self.var_transform_engine(batch)
        return batch

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        from torch.utils.data import DistributedSampler
        sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
        self._train_dataloader = DataLoader(
            self.train_dataset,
            self.train_batch_size,
            timeout=6000,
            num_workers=self.train_num_workers,
            prefetch_factor=self.train_prefetch_factor,
            sampler=sampler,
            collate_fn=collate_fn,
        )
        return self._train_dataloader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        self.eval_dataset = RandomNDataset(
            latent_shape=self.latent_shape,
            num_classes=self.num_classes,
            max_num_instances=self.eval_max_num_instances,
        )
        from torch.utils.data import DistributedSampler
        sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(self.eval_dataset, self.eval_batch_size,
                          num_workers=self.eval_num_workers,
                          prefetch_factor=2,
                          collate_fn=collate_fn,
                          sampler=sampler
                          )

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        self.pred_dataset = RandomNDataset(
            seeds=self.pred_seeds,
            max_num_instances=50000,
            num_classes=self.num_classes,
            selected_classes=self.pred_selected_classes,
            latent_shape=self.latent_shape,
        )
        from torch.utils.data import DistributedSampler
        sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
                          num_workers=self.pred_num_workers,
                          prefetch_factor=4,
                          collate_fn=collate_fn,
                          sampler=sampler
                          )
src/lightning_model.py
ADDED
@@ -0,0 +1,123 @@
from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
import os.path
import copy
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback


from src.models.vae import BaseVAE, fp2uint8
from src.models.conditioner import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params

EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]


class LightningModel(pl.LightningModule):
    def __init__(self,
                 vae: BaseVAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: Optional[EMACallable] = None,
                 optimizer: OptimizerCallable = None,
                 lr_scheduler: LRSchedulerCallable = None,
                 ):
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # self.model_loader = ModelLoader()

        self._strict_loading = False

    def configure_model(self) -> None:
        self.trainer.strategy.barrier()
        # self.denoiser = self.model_loader.load(self.denoiser)
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)

        # self.denoiser = torch.compile(self.denoiser)
        # disable grad for conditioner and vae
        no_grad(self.conditioner)
        no_grad(self.vae)
        no_grad(self.diffusion_sampler)
        no_grad(self.ema_denoiser)

    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
        return [ema_tracker]

    def configure_optimizers(self) -> OptimizerLRScheduler:
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
        if self.lr_scheduler is None:
            return dict(
                optimizer=optimizer
            )
        else:
            lr_scheduler = self.lr_scheduler(optimizer)
            return dict(
                optimizer=optimizer,
                lr_scheduler=lr_scheduler
            )

    def training_step(self, batch, batch_idx):
        raw_images, x, y = batch
        with torch.no_grad():
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y)
        loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
        self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
        return loss["loss"]

    def predict_step(self, batch, batch_idx):
        xT, y, metadata = batch
        with torch.no_grad():
            condition, uncondition = self.conditioner(y)
            # Sample images:
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
            samples = self.vae.decode(samples)
        # fp32 -1,1 -> uint8 0,255
        samples = fp2uint8(samples)
        return samples

    def validation_step(self, batch, batch_idx):
        samples = self.predict_step(batch, batch_idx)
        return samples

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix+"denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix+"ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix+"diffusion_trainer.",
            keep_vars=keep_vars)
        return destination
src/models/__init__.py
ADDED
File without changes
src/models/conditioner.py
ADDED
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

class BaseConditioner(nn.Module):
    def __init__(self):
        super(BaseConditioner, self).__init__()

    def _impl_condition(self, y):
        ...
    def _impl_uncondition(self, y):
        ...
    def __call__(self, y):
        condition = self._impl_condition(y)
        uncondition = self._impl_uncondition(y)
        return condition, uncondition

class LabelConditioner(BaseConditioner):
    def __init__(self, null_class):
        super().__init__()
        self.null_condition = null_class

    def _impl_condition(self, y):
        return torch.tensor(y).long().cuda()

    def _impl_uncondition(self, y):
        return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()
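A quick usage sketch of LabelConditioner: for class-conditional ImageNet the null class used for classifier-free guidance is the extra index 1000. Both branches call .cuda(), so this sketch assumes a GPU is available:

# Assumes a CUDA device, since LabelConditioner moves tensors to .cuda() internally.
conditioner = LabelConditioner(null_class=1000)
condition, uncondition = conditioner([207, 360, 387])  # three ImageNet class ids
print(condition.tolist())    # [207, 360, 387]
print(uncondition.tolist())  # [1000, 1000, 1000]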
src/models/denoiser/__init__.py
ADDED
File without changes
src/models/denoiser/decoupled_improved_dit.py
ADDED
@@ -0,0 +1,308 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math

from torch.nn.init import zeros_
from torch.nn.modules.module import T

# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention


def modulate(x, shift, scale):
    return x * (1 + scale) + shift

class Embed(nn.Module):
    def __init__(
            self,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer = None,
            bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x

class TimestepEmbedder(nn.Module):

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels,):
        embeddings = self.embedding_table(labels)
        return embeddings

class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2*hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
    def forward(self, x):
        x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
        return x

def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    # assert H * H == end
    # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
    x_pos = torch.linspace(0, scale, width)
    y_pos = torch.linspace(0, scale, height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))  # Hc/4
    x_freqs = torch.outer(x_pos, freqs).float()  # N Hc/4
    y_freqs = torch.outer(y_pos, freqs).float()  # N Hc/4
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)  # N,Hc/4,2
    freqs_cis = freqs_cis.reshape(height*width, -1)
    return freqs_cis


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs_cis = freqs_cis[None, :, None, :]
    # xq : B N H Hc
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # B N H Hc/2
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # B, N, H, Hc
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



class DDTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class DDT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        num_encoder_blocks=4,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        deep_supervision=0,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.num_encoder_blocks = num_encoder_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)

        self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)

        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)


    def forward(self, x, t, y, s=None, mask=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        if s is None:
            s = self.s_embedder(x)
            for i in range(self.num_encoder_blocks):
                s = self.blocks[i](s, c, pos, mask)
            s = nn.functional.silu(t + s)

        x = self.x_embedder(x)
        for i in range(self.num_encoder_blocks, self.num_blocks):
            x = self.blocks[i](x, s, pos, None)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s
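A hedged smoke test of the decoupled forward pass above, on a tiny made-up configuration (real configs use hidden_size 1152 and 28 blocks): the first call runs the encoder blocks and returns the state s; passing s back in skips the encoder, which is exactly what the state-sharing sampler exploits.

# Tiny illustrative configuration; all sizes here are assumptions for the sketch.
model = DDT(in_channels=4, num_groups=4, hidden_size=64, num_blocks=4,
            num_encoder_blocks=2, patch_size=2, num_classes=10)
x = torch.randn(2, 4, 16, 16)            # noisy latents
t = torch.rand(2)                         # timesteps in [0, 1]
y = torch.randint(0, 10, (2,))            # class labels
v, s = model(x, t, y)                     # encoder runs, state s is produced
v2, _ = model(x, t, y, s=s)               # encoder skipped, state s reused
print(v.shape, s.shape)                   # torch.Size([2, 4, 16, 16]) torch.Size([2, 64, 64])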
src/models/denoiser/improved_dit.py
ADDED
@@ -0,0 +1,301 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math

from torch.nn.init import zeros_
from torch.nn.modules.module import T

# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention


def modulate(x, shift, scale):
    return x * (1 + scale) + shift

class Embed(nn.Module):
    def __init__(
            self,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer = None,
            bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x

class TimestepEmbedder(nn.Module):

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels,):
        embeddings = self.embedding_table(labels)
        return embeddings

class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2*hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
    def forward(self, x):
        x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
        return x

def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    # assert H * H == end
    # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
    x_pos = torch.linspace(0, scale, width)
    y_pos = torch.linspace(0, scale, height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))  # Hc/4
    x_freqs = torch.outer(x_pos, freqs).float()  # N Hc/4
    y_freqs = torch.outer(y_pos, freqs).float()  # N Hc/4
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)  # N,Hc/4,2
    freqs_cis = freqs_cis.reshape(height*width, -1)
    return freqs_cis


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs_cis = freqs_cis[None, :, None, :]
    # xq : B N H Hc
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # B N H Hc/2
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # B, N, H, Hc
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



class DiTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class DiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        deep_supervision=0,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)

        self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)

        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device, dtype):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device, dtype)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, y, masks=None):
        if masks is None:
            masks = [None, ]*self.num_blocks
        if isinstance(masks, torch.Tensor):
            masks = masks.unbind(0)
        if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
            masks = masks + [None]*(self.num_blocks-len(masks))

        B, _, H, W = x.shape
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        x = self.x_embedder(x)
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
        B, L, C = x.shape
        t = self.t_embedder(t.view(-1)).view(B, -1, C)
        y = self.y_embedder(y).view(B, 1, C)
        condition = nn.functional.silu(t + y)
        for i, block in enumerate(self.blocks):
            x = block(x, condition, pos, masks[i])
        x = self.final_layer(x, condition)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x
src/models/encoder.py
ADDED
@@ -0,0 +1,132 @@
import torch
import copy
import os
import timm
import transformers
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import Normalize


class RandViT(nn.Module):
    """Randomly initialized ViT feature extractor (no pretrained weights)."""
    def __init__(self, model_id, weight_path: str = None):
        super(RandViT, self).__init__()
        self.encoder = timm.create_model(
            model_id,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        # Normalize to ImageNet statistics and resize to the ViT input resolution.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop prefix (CLS/register) tokens and reshape tokens into a 2D feature map.
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


class MAE(nn.Module):
    """MAE-pretrained ViT feature extractor loaded from a timm checkpoint."""
    def __init__(self, model_id, weight_path: str):
        super(MAE, self).__init__()
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


class DINO(nn.Module):
    """DINO-pretrained ViT feature extractor loaded from a timm checkpoint."""
    def __init__(self, model_id, weight_path: str):
        super(DINO, self).__init__()
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


class CLIP(nn.Module):
    """CLIP vision tower feature extractor loaded via transformers."""
    def __init__(self, model_id, weight_path: str):
        super(CLIP, self).__init__()
        self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
        self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop the CLS token and reshape tokens into a 2D feature map.
        feature = self.encoder(x)['last_hidden_state'][:, 1:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


class DINOv2(nn.Module):
    """DINOv2 feature extractor loaded via transformers; returns features without shift/scale."""
    def __init__(self, model_id, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
        self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature
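Each wrapper maps an image batch in [0, 1] to a (B, C, H/patch, W/patch) feature grid after resizing to 224x224. A minimal usage sketch with the DINOv2 wrapper follows; the weight path is a placeholder for a local or hub copy of a DINOv2 checkpoint, not a path shipped with this Space.

import torch
from src.models.encoder import DINOv2

# Placeholder checkpoint path for illustration only.
encoder = DINOv2(model_id="dinov2", weight_path="facebook/dinov2-base")
images = torch.rand(2, 3, 256, 256)       # RGB images in [0, 1]
with torch.no_grad():
    feats = encoder(images)               # (2, hidden_dim, 224 // patch, 224 // patch)
print(feats.shape)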
src/models/vae.py
ADDED
@@ -0,0 +1,81 @@
import torch
import subprocess
import lightning.pytorch as pl

import logging


logger = logging.getLogger(__name__)


def class_fn_from_str(class_str):
    # Resolve a dotted path such as "src.models.vae.LatentVAE" to the class object.
    class_module, from_class = class_str.rsplit(".", 1)
    class_module = __import__(class_module, fromlist=[from_class])
    return getattr(class_module, from_class)


class BaseVAE(torch.nn.Module):
    """Identity VAE that only applies an affine scale/shift in latent space."""
    def __init__(self, scale=1.0, shift=0.0):
        super().__init__()
        self.model = torch.nn.Identity()
        self.scale = scale
        self.shift = shift

    def encode(self, x):
        return x / self.scale + self.shift

    def decode(self, x):
        return (x - self.shift) * self.scale


# Note: nearest-neighbor resampling caused severe artifacts here, so bicubic is used.
class DownSampleVAE(BaseVAE):
    def __init__(self, down_ratio, scale=1.0, shift=0.0):
        super().__init__()
        self.model = torch.nn.Identity()
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def encode(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=1 / self.down_ratio, mode='bicubic', align_corners=False)
        return x / self.scale + self.shift

    def decode(self, x):
        x = (x - self.shift) * self.scale
        x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False)
        return x


class LatentVAE(BaseVAE):
    """Wraps a diffusers AutoencoderKL; optionally assumes precomputed latents."""
    def __init__(self, precompute=False, weight_path: str = None):
        super().__init__()
        self.precompute = precompute
        self.model = None
        self.weight_path = weight_path

        from diffusers.models import AutoencoderKL
        setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path))
        self.scaling_factor = self.model.config.scaling_factor

    @torch.no_grad()
    def encode(self, x):
        assert self.model is not None
        if self.precompute:
            # Inputs are already VAE latents; only apply the scaling factor.
            return x.mul_(self.scaling_factor)
        return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)

    @torch.no_grad()
    def decode(self, x):
        assert self.model is not None
        return self.model.decode(x.div_(self.scaling_factor)).sample


def uint82fp(x):
    # Map uint8 images in [0, 255] to float32 in [-1, 1].
    x = x.to(torch.float32)
    x = (x - 127.5) / 127.5
    return x


def fp2uint8(x):
    # Map float images in [-1, 1] back to uint8 in [0, 255].
    x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
    return x
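A minimal encode/decode round-trip sketch for `LatentVAE` follows; the weight path is a placeholder for a local copy of an AutoencoderKL checkpoint such as sd-vae-ft-ema. Note that `decode` divides its input in place, so the latents are cloned before decoding.

import torch
from src.models.vae import LatentVAE, fp2uint8

vae = LatentVAE(precompute=False, weight_path="stabilityai/sd-vae-ft-ema")  # placeholder path
images = torch.randn(1, 3, 256, 256)            # images in [-1, 1]
latents = vae.encode(images)                    # sampled latents, scaled by scaling_factor
recon = vae.decode(latents.clone())             # clone: decode() divides its input in place
recon_uint8 = fp2uint8(recon)                   # back to uint8 in [0, 255]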