wangshuai6 committed on
Commit 9e426da · 1 Parent(s): df79fa7

init space

Files changed (50)
  1. app.py +170 -0
  2. configs/repa_improved_ddt_xlen22de6_256.yaml +108 -0
  3. configs/repa_improved_ddt_xlen22de6_512.yaml +108 -0
  4. configs/repa_improved_dit_large.yaml +99 -0
  5. configs/repa_improved_dit_xl.yaml +99 -0
  6. requirements.txt +6 -0
  7. src/__init__.py +0 -0
  8. src/callbacks/__init__.py +0 -0
  9. src/callbacks/grad.py +22 -0
  10. src/callbacks/model_checkpoint.py +21 -0
  11. src/callbacks/save_images.py +105 -0
  12. src/callbacks/simple_ema.py +79 -0
  13. src/data/__init__.py +1 -0
  14. src/data/dataset/__init__.py +0 -0
  15. src/data/dataset/celeba.py +11 -0
  16. src/data/dataset/imagenet.py +82 -0
  17. src/data/dataset/metric_dataset.py +82 -0
  18. src/data/dataset/randn.py +41 -0
  19. src/data/var_training.py +145 -0
  20. src/diffusion/__init__.py +0 -0
  21. src/diffusion/base/guidance.py +60 -0
  22. src/diffusion/base/sampling.py +31 -0
  23. src/diffusion/base/scheduling.py +32 -0
  24. src/diffusion/base/training.py +29 -0
  25. src/diffusion/ddpm/ddim_sampling.py +40 -0
  26. src/diffusion/ddpm/scheduling.py +102 -0
  27. src/diffusion/ddpm/training.py +83 -0
  28. src/diffusion/ddpm/vp_sampling.py +59 -0
  29. src/diffusion/flow_matching/adam_sampling.py +107 -0
  30. src/diffusion/flow_matching/sampling.py +179 -0
  31. src/diffusion/flow_matching/scheduling.py +39 -0
  32. src/diffusion/flow_matching/training.py +55 -0
  33. src/diffusion/flow_matching/training_cos.py +59 -0
  34. src/diffusion/flow_matching/training_repa.py +137 -0
  35. src/diffusion/pre_integral.py +143 -0
  36. src/diffusion/stateful_flow_matching/adam_sampling.py +112 -0
  37. src/diffusion/stateful_flow_matching/sampling.py +103 -0
  38. src/diffusion/stateful_flow_matching/scheduling.py +39 -0
  39. src/diffusion/stateful_flow_matching/sharing_sampling.py +149 -0
  40. src/diffusion/stateful_flow_matching/training.py +55 -0
  41. src/diffusion/stateful_flow_matching/training_repa.py +152 -0
  42. src/lightning_data.py +162 -0
  43. src/lightning_model.py +123 -0
  44. src/models/__init__.py +0 -0
  45. src/models/conditioner.py +26 -0
  46. src/models/denoiser/__init__.py +0 -0
  47. src/models/denoiser/decoupled_improved_dit.py +308 -0
  48. src/models/denoiser/improved_dit.py +301 -0
  49. src/models/encoder.py +132 -0
  50. src/models/vae.py +81 -0
app.py ADDED
@@ -0,0 +1,170 @@
+ # vae:
+ #   class_path: src.models.vae.LatentVAE
+ #   init_args:
+ #     precompute: true
+ #     weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+ # denoiser:
+ #   class_path: src.models.denoiser.decoupled_improved_dit.DDT
+ #   init_args:
+ #     in_channels: 4
+ #     patch_size: 2
+ #     num_groups: 16
+ #     hidden_size: &hidden_dim 1152
+ #     num_blocks: 28
+ #     num_encoder_blocks: 22
+ #     num_classes: 1000
+ # conditioner:
+ #   class_path: src.models.conditioner.LabelConditioner
+ #   init_args:
+ #     null_class: 1000
+ # diffusion_sampler:
+ #   class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+ #   init_args:
+ #     num_steps: 250
+ #     guidance: 3.0
+ #     state_refresh_rate: 1
+ #     guidance_interval_min: 0.3
+ #     guidance_interval_max: 1.0
+ #     timeshift: 1.0
+ #     last_step: 0.04
+ #     scheduler: *scheduler
+ #     w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+ #     guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+ #     step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+
+ import torch
+ import argparse
+ from omegaconf import OmegaConf
+ from src.models.vae import fp2uint8
+ from src.diffusion.base.guidance import simple_guidance_fn
+ from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
+ from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
+ from PIL import Image
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+
+
+ def instantiate_class(config):
+     # Resolve a {"class_path": ..., "init_args": ...} config node into an instance.
+     kwargs = config.get("init_args", {})
+     class_module, class_name = config["class_path"].rsplit(".", 1)
+     module = __import__(class_module, fromlist=[class_name])
+     args_class = getattr(module, class_name)
+     return args_class(**kwargs)
+
+ def load_model(weight_dict, denoiser):
+     # Copy EMA denoiser weights from the checkpoint state_dict into the denoiser.
+     prefix = "ema_denoiser."
+     for k, v in denoiser.state_dict().items():
+         try:
+             v.copy_(weight_dict["state_dict"][prefix + k])
+         except Exception:
+             print(f"Failed to copy {prefix + k} to denoiser weight")
+     return denoiser
+
+
+ class Pipeline:
+     def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
+         self.vae = vae
+         self.denoiser = denoiser
+         self.conditioner = conditioner
+         self.diffusion_sampler = diffusion_sampler
+         self.resolution = resolution
+
+     @torch.no_grad()
+     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+     def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
+         self.diffusion_sampler.num_steps = num_steps
+         self.diffusion_sampler.guidance = guidance
+         self.diffusion_sampler.state_refresh_rate = state_refresh_rate
+         self.diffusion_sampler.guidance_interval_min = guidance_interval_min
+         self.diffusion_sampler.guidance_interval_max = guidance_interval_max
+         self.diffusion_sampler.timeshift = timeshift
+         generator = torch.Generator(device="cuda").manual_seed(seed)
+         xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
+         with torch.no_grad():
+             condition, uncondition = self.conditioner([y,]*num_images)
+             # Sample images:
+             samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
+             samples = self.vae.decode(samples)
+         # fp32 -1,1 -> uint8 0,255
+         samples = fp2uint8(samples)
+         samples = samples.permute(0, 2, 3, 1).cpu().numpy()
+         images = []
+         for i in range(num_images):
+             image = Image.fromarray(samples[i])
+             images.append(image)
+         return images
+
+ import os
+ import spaces
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
+     parser.add_argument("--resolution", type=int, default=512)
+     parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R512")
+     parser.add_argument("--ckpt_path", type=str, default="models")
+     args = parser.parse_args()
+
+     if not os.path.exists(args.ckpt_path):
+         snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)
+
+     config = OmegaConf.load(args.config)
+     vae_config = config.model.vae
+     diffusion_sampler_config = config.model.diffusion_sampler
+     denoiser_config = config.model.denoiser
+     conditioner_config = config.model.conditioner
+
+     vae = instantiate_class(vae_config)
+     denoiser = instantiate_class(denoiser_config)
+     conditioner = instantiate_class(conditioner_config)
+
+     diffusion_sampler = EulerSampler(
+         scheduler=LinearScheduler(),
+         w_scheduler=LinearScheduler(),
+         guidance_fn=simple_guidance_fn,
+         num_steps=50,
+         guidance=3.0,
+         state_refresh_rate=1,
+         guidance_interval_min=0.3,
+         guidance_interval_max=1.0,
+         timeshift=1.0
+     )
+     ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
+     ckpt = torch.load(ckpt_path, map_location="cpu")
+     denoiser = load_model(ckpt, denoiser)
+     denoiser = denoiser.cuda()
+     vae = vae.cuda()
+     denoiser.eval()
+
+     pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
+
+     with gr.Blocks() as demo:
+         gr.Markdown("DDT")
+         with gr.Row():
+             with gr.Column(scale=1):
+                 num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
+                 guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
+                 num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
+                 label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
+                 seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
+                 state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
+                 guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
+                 guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
+                 timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
+             with gr.Column(scale=2):
+                 btn = gr.Button("Generate")
+                 output = gr.Gallery(label="Images")
+
+         btn.click(fn=pipeline,
+                   inputs=[
+                       label,
+                       num_images,
+                       seed,
+                       num_steps,
+                       guidance,
+                       state_refresh_rate,
+                       guidance_interval_min,
+                       guidance_interval_max,
+                       timeshift
+                   ], outputs=[output])
+     demo.launch(server_name="0.0.0.0", server_port=7861)
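
Editor's note: a minimal sketch of how the `instantiate_class` helper above resolves one `class_path`/`init_args` node. The values are copied from the conditioner entry in the configs added below, so this mirrors what the `__main__` block does for the conditioner; the dict literal itself is illustrative, not part of the commit.

    # Hypothetical standalone usage of instantiate_class (defined in app.py above).
    cfg = {
        "class_path": "src.models.conditioner.LabelConditioner",
        "init_args": {"null_class": 1000},
    }
    conditioner = instantiate_class(cfg)
    # equivalent to: from src.models.conditioner import LabelConditioner; LabelConditioner(null_class=1000)
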
configs/repa_improved_ddt_xlen22de6_256.yaml ADDED
@@ -0,0 +1,108 @@
+ # lightning.pytorch==2.4.0
+ seed_everything: true
+ tags:
+   exp: &exp repa_flatten_condit22_dit6_fixt_xl
+   torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+   huggingface_cache_dir: null
+ trainer:
+   default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+   accelerator: auto
+   strategy: auto
+   devices: auto
+   num_nodes: 1
+   precision: bf16-mixed
+   logger:
+     class_path: lightning.pytorch.loggers.WandbLogger
+     init_args:
+       project: universal_flow
+       name: *exp
+   num_sanity_val_steps: 0
+   max_steps: 4000000
+   val_check_interval: 4000000
+   check_val_every_n_epoch: null
+   log_every_n_steps: 50
+   deterministic: null
+   inference_mode: true
+   use_distributed_sampler: false
+   callbacks:
+     - class_path: src.callbacks.model_checkpoint.CheckpointHook
+       init_args:
+         every_n_train_steps: 10000
+         save_top_k: -1
+         save_last: true
+     - class_path: src.callbacks.save_images.SaveImagesHook
+       init_args:
+         save_dir: val
+   plugins:
+     - src.plugins.bd_env.BDEnvironment
+ model:
+   vae:
+     class_path: src.models.vae.LatentVAE
+     init_args:
+       precompute: true
+       weight_path: stabilityai/sd-vae-ft-ema
+   denoiser:
+     class_path: src.models.denoiser.decoupled_improved_dit.DDT
+     init_args:
+       in_channels: 4
+       patch_size: 2
+       num_groups: 16
+       hidden_size: &hidden_dim 1152
+       num_blocks: 28
+       num_encoder_blocks: 22
+       num_classes: 1000
+   conditioner:
+     class_path: src.models.conditioner.LabelConditioner
+     init_args:
+       null_class: 1000
+   diffusion_trainer:
+     class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
+     init_args:
+       lognorm_t: true
+       encoder_weight_path: dinov2_vitb14
+       align_layer: 8
+       proj_denoiser_dim: *hidden_dim
+       proj_hidden_dim: *hidden_dim
+       proj_encoder_dim: 768
+       scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+   diffusion_sampler:
+     class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+     init_args:
+       num_steps: 250
+       guidance: 2.0
+       timeshift: 1.0
+       state_refresh_rate: 1
+       guidance_interval_min: 0.3
+       guidance_interval_max: 1.0
+       scheduler: *scheduler
+       w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+       guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+       last_step: 0.04
+       step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+   ema_tracker:
+     class_path: src.callbacks.simple_ema.SimpleEMA
+     init_args:
+       decay: 0.9999
+   optimizer:
+     class_path: torch.optim.AdamW
+     init_args:
+       lr: 1e-4
+       betas:
+         - 0.9
+         - 0.95
+       weight_decay: 0.0
+ data:
+   train_dataset: imagenet256
+   train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+   train_image_size: 256
+   train_batch_size: 16
+   eval_max_num_instances: 50000
+   pred_batch_size: 64
+   pred_num_workers: 4
+   pred_seeds: null
+   pred_selected_classes: null
+   num_classes: 1000
+   latent_shape:
+     - 4
+     - 32
+     - 32
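
Editor's note: the `&hidden_dim`/`*hidden_dim` and `&scheduler`/`*scheduler` pairs above are ordinary YAML anchors and aliases, so the REPA projector dimensions and the sampler's scheduler track the values defined once in the denoiser and trainer blocks. A hedged sketch of how such a node resolves when this file is loaded (OmegaConf usage mirrors app.py; aliases are resolved at parse time):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/repa_improved_ddt_xlen22de6_256.yaml")
    assert cfg.model.denoiser.init_args.hidden_size == 1152
    assert cfg.model.diffusion_trainer.init_args.proj_denoiser_dim == 1152  # *hidden_dim alias
    print(cfg.model.diffusion_sampler.init_args.scheduler)
    # -> src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
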
configs/repa_improved_ddt_xlen22de6_512.yaml ADDED
@@ -0,0 +1,108 @@
+ # lightning.pytorch==2.4.0
+ seed_everything: true
+ tags:
+   exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
+   torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+   huggingface_cache_dir: null
+ trainer:
+   default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+   accelerator: auto
+   strategy: auto
+   devices: auto
+   num_nodes: 1
+   precision: bf16-mixed
+   logger:
+     class_path: lightning.pytorch.loggers.WandbLogger
+     init_args:
+       project: universal_flow
+       name: *exp
+   num_sanity_val_steps: 0
+   max_steps: 4000000
+   val_check_interval: 4000000
+   check_val_every_n_epoch: null
+   log_every_n_steps: 50
+   deterministic: null
+   inference_mode: true
+   use_distributed_sampler: false
+   callbacks:
+     - class_path: src.callbacks.model_checkpoint.CheckpointHook
+       init_args:
+         every_n_train_steps: 10000
+         save_top_k: -1
+         save_last: true
+     - class_path: src.callbacks.save_images.SaveImagesHook
+       init_args:
+         save_dir: val
+   plugins:
+     - src.plugins.bd_env.BDEnvironment
+ model:
+   vae:
+     class_path: src.models.vae.LatentVAE
+     init_args:
+       precompute: true
+       weight_path: stabilityai/sd-vae-ft-ema
+   denoiser:
+     class_path: src.models.denoiser.decoupled_improved_dit.DDT
+     init_args:
+       in_channels: 4
+       patch_size: 2
+       num_groups: 16
+       hidden_size: &hidden_dim 1152
+       num_blocks: 28
+       num_encoder_blocks: 22
+       num_classes: 1000
+   conditioner:
+     class_path: src.models.conditioner.LabelConditioner
+     init_args:
+       null_class: 1000
+   diffusion_trainer:
+     class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
+     init_args:
+       lognorm_t: true
+       encoder_weight_path: dinov2_vitb14
+       align_layer: 8
+       proj_denoiser_dim: *hidden_dim
+       proj_hidden_dim: *hidden_dim
+       proj_encoder_dim: 768
+       scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+   diffusion_sampler:
+     class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+     init_args:
+       num_steps: 250
+       guidance: 3.0
+       state_refresh_rate: 1
+       guidance_interval_min: 0.3
+       guidance_interval_max: 1.0
+       timeshift: 1.0
+       last_step: 0.04
+       scheduler: *scheduler
+       w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+       guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+       step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+   ema_tracker:
+     class_path: src.callbacks.simple_ema.SimpleEMA
+     init_args:
+       decay: 0.9999
+   optimizer:
+     class_path: torch.optim.AdamW
+     init_args:
+       lr: 1e-4
+       betas:
+         - 0.9
+         - 0.95
+       weight_decay: 0.0
+ data:
+   train_dataset: imagenet512
+   train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+   train_image_size: 512
+   train_batch_size: 16
+   eval_max_num_instances: 50000
+   pred_batch_size: 32
+   pred_num_workers: 4
+   pred_seeds: null
+   pred_selected_classes: null
+   num_classes: 1000
+   latent_shape:
+     - 4
+     - 64
+     - 64
configs/repa_improved_dit_large.yaml ADDED
@@ -0,0 +1,99 @@
+ # lightning.pytorch==2.4.0
+ seed_everything: true
+ tags:
+   exp: &exp repa_improved_dit_large
+   torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+   huggingface_cache_dir: null
+ trainer:
+   default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+   accelerator: auto
+   strategy: auto
+   devices: auto
+   num_nodes: 1
+   precision: bf16-mixed
+   logger:
+     class_path: lightning.pytorch.loggers.WandbLogger
+     init_args:
+       project: universal_flow
+       name: *exp
+   num_sanity_val_steps: 0
+   max_steps: 400000
+   val_check_interval: 100000
+   check_val_every_n_epoch: null
+   log_every_n_steps: 50
+   deterministic: null
+   inference_mode: true
+   use_distributed_sampler: false
+   callbacks:
+     - class_path: src.callbacks.model_checkpoint.CheckpointHook
+       init_args:
+         every_n_train_steps: 10000
+         save_top_k: -1
+         save_last: true
+     - class_path: src.callbacks.save_images.SaveImagesHook
+       init_args:
+         save_dir: val
+   plugins:
+     - src.plugins.bd_env.BDEnvironment
+ model:
+   vae:
+     class_path: src.models.vae.LatentVAE
+     init_args:
+       precompute: true
+       weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+   denoiser:
+     class_path: src.models.denoiser.improved_dit.DiT
+     init_args:
+       in_channels: 4
+       patch_size: 2
+       num_groups: 16
+       hidden_size: &hidden_dim 1024
+       num_blocks: 24
+       num_classes: 1000
+   conditioner:
+     class_path: src.models.conditioner.LabelConditioner
+     init_args:
+       null_class: 1000
+   diffusion_trainer:
+     class_path: src.diffusion.flow_matching.training_repa.REPATrainer
+     init_args:
+       lognorm_t: true
+       encoder_weight_path: dinov2_vitb14
+       align_layer: 8
+       proj_denoiser_dim: *hidden_dim
+       proj_hidden_dim: *hidden_dim
+       proj_encoder_dim: 768
+       scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
+   diffusion_sampler:
+     class_path: src.diffusion.flow_matching.sampling.EulerSampler
+     init_args:
+       num_steps: 250
+       guidance: 1.00
+       scheduler: *scheduler
+       w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
+       guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+       step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
+   ema_tracker:
+     class_path: src.callbacks.simple_ema.SimpleEMA
+     init_args:
+       decay: 0.9999
+   optimizer:
+     class_path: torch.optim.AdamW
+     init_args:
+       lr: 1e-4
+       weight_decay: 0.0
+ data:
+   train_dataset: imagenet256
+   train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+   train_image_size: 256
+   train_batch_size: 32
+   eval_max_num_instances: 50000
+   pred_batch_size: 64
+   pred_num_workers: 4
+   pred_seeds: null
+   pred_selected_classes: null
+   num_classes: 1000
+   latent_shape:
+     - 4
+     - 32
+     - 32
configs/repa_improved_dit_xl.yaml ADDED
@@ -0,0 +1,99 @@
+ # lightning.pytorch==2.4.0
+ seed_everything: true
+ tags:
+   exp: &exp repa_improved_dit_xlen22de6_512
+   torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+   huggingface_cache_dir: null
+ trainer:
+   default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+   accelerator: auto
+   strategy: auto
+   devices: auto
+   num_nodes: 1
+   precision: bf16-mixed
+   logger:
+     class_path: lightning.pytorch.loggers.WandbLogger
+     init_args:
+       project: universal_flow
+       name: *exp
+   num_sanity_val_steps: 0
+   max_steps: 400000
+   val_check_interval: 100000
+   check_val_every_n_epoch: null
+   log_every_n_steps: 50
+   deterministic: null
+   inference_mode: true
+   use_distributed_sampler: false
+   callbacks:
+     - class_path: src.callbacks.model_checkpoint.CheckpointHook
+       init_args:
+         every_n_train_steps: 10000
+         save_top_k: -1
+         save_last: true
+     - class_path: src.callbacks.save_images.SaveImagesHook
+       init_args:
+         save_dir: val
+   plugins:
+     - src.plugins.bd_env.BDEnvironment
+ model:
+   vae:
+     class_path: src.models.vae.LatentVAE
+     init_args:
+       precompute: true
+       weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+   denoiser:
+     class_path: src.models.denoiser.improved_dit.DiT
+     init_args:
+       in_channels: 4
+       patch_size: 2
+       num_groups: 16
+       hidden_size: &hidden_dim 1152
+       num_blocks: 28
+       num_classes: 1000
+   conditioner:
+     class_path: src.models.conditioner.LabelConditioner
+     init_args:
+       null_class: 1000
+   diffusion_trainer:
+     class_path: src.diffusion.flow_matching.training_repa.REPATrainer
+     init_args:
+       lognorm_t: true
+       encoder_weight_path: dinov2_vitb14
+       align_layer: 8
+       proj_denoiser_dim: *hidden_dim
+       proj_hidden_dim: *hidden_dim
+       proj_encoder_dim: 768
+       scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
+   diffusion_sampler:
+     class_path: src.diffusion.flow_matching.sampling.EulerSampler
+     init_args:
+       num_steps: 250
+       guidance: 1.00
+       scheduler: *scheduler
+       w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
+       guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+       step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
+   ema_tracker:
+     class_path: src.callbacks.simple_ema.SimpleEMA
+     init_args:
+       decay: 0.9999
+   optimizer:
+     class_path: torch.optim.AdamW
+     init_args:
+       lr: 1e-4
+       weight_decay: 0.0
+ data:
+   train_dataset: imagenet256
+   train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+   train_image_size: 256
+   train_batch_size: 32
+   eval_max_num_instances: 50000
+   pred_batch_size: 64
+   pred_num_workers: 4
+   pred_seeds: null
+   pred_selected_classes: null
+   num_classes: 1000
+   latent_shape:
+     - 4
+     - 32
+     - 32
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ lightning==2.5.0.post0
+ omegaconf==2.3.0
+ torch==2.3.0
+ diffusers==0.30.0
+ jsonargparse[signatures]>=4.27.7
+ accelerate
src/__init__.py ADDED
File without changes
src/callbacks/__init__.py ADDED
File without changes
src/callbacks/grad.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import lightning.pytorch as pl
+ from lightning.pytorch.utilities import grad_norm
+ from torch.optim import Optimizer
+
+ class GradientMonitor(pl.Callback):
+     """Logs the gradient norm"""
+
+     def __init__(self, norm_type: int = 2):
+         norm_type = float(norm_type)
+         if norm_type <= 0:
+             raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
+         self.norm_type = norm_type
+
+     def on_before_optimizer_step(
+         self, trainer: "pl.Trainer",
+         pl_module: "pl.LightningModule",
+         optimizer: Optimizer
+     ) -> None:
+         norms = grad_norm(pl_module, norm_type=self.norm_type)
+         max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
+         pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})
src/callbacks/model_checkpoint.py ADDED
@@ -0,0 +1,21 @@
+ import os.path
+ from typing import Optional, Dict, Any
+
+ import lightning.pytorch as pl
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+
+
+ class CheckpointHook(ModelCheckpoint):
+     """Save checkpoint with only the incremental part of the model"""
+     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
+         self.dirpath = trainer.default_root_dir
+         self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
+         pl_module.strict_loading = False
+
+     def on_save_checkpoint(
+         self, trainer: "pl.Trainer",
+         pl_module: "pl.LightningModule",
+         checkpoint: Dict[str, Any]
+     ) -> None:
+         # Drop callback states so the checkpoint only keeps the model-side increment.
+         del checkpoint["callbacks"]
src/callbacks/save_images.py ADDED
@@ -0,0 +1,105 @@
+ import lightning.pytorch as pl
+ from lightning.pytorch import Callback
+
+
+ import os.path
+ import numpy
+ from PIL import Image
+ from typing import Sequence, Any, Dict
+ from concurrent.futures import ThreadPoolExecutor
+
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from lightning_utilities.core.rank_zero import rank_zero_info
+
+ def process_fn(image, path):
+     Image.fromarray(image).save(path)
+
+ class SaveImagesHook(Callback):
+     def __init__(self, save_dir="val", max_save_num=0, compressed=True):
+         self.save_dir = save_dir
+         self.max_save_num = max_save_num
+         self.compressed = compressed
+
+     def save_start(self, target_dir):
+         self.target_dir = target_dir
+         self.executor_pool = ThreadPoolExecutor(max_workers=8)
+         if not os.path.exists(self.target_dir):
+             os.makedirs(self.target_dir, exist_ok=True)
+         else:
+             if os.listdir(target_dir) and "debug" not in str(target_dir):
+                 raise FileExistsError(f'{self.target_dir} already exists and not empty!')
+         self.samples = []
+         self._have_saved_num = 0
+         rank_zero_info(f"Save images to {self.target_dir}")
+
+     def save_image(self, images, filenames):
+         images = images.permute(0, 2, 3, 1).cpu().numpy()
+         for sample, filename in zip(images, filenames):
+             if isinstance(filename, Sequence):
+                 filename = filename[0]
+             path = f'{self.target_dir}/{filename}'
+             if self._have_saved_num >= self.max_save_num:
+                 break
+             self.executor_pool.submit(process_fn, sample, path)
+             self._have_saved_num += 1
+
+     def process_batch(
+         self,
+         trainer: "pl.Trainer",
+         pl_module: "pl.LightningModule",
+         samples: STEP_OUTPUT,
+         batch: Any,
+     ) -> None:
+         b, c, h, w = samples.shape
+         xT, y, metadata = batch
+         all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
+         self.save_image(samples, metadata)
+         if trainer.is_global_zero:
+             all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
+             self.samples.append(all_samples)
+
+     def save_end(self):
+         if self.compressed and len(self.samples) > 0:
+             samples = numpy.concatenate(self.samples)
+             numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
+         self.executor_pool.shutdown(wait=True)
+         self.samples = []
+         self.target_dir = None
+         self._have_saved_num = 0
+         self.executor_pool = None
+
+     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
+         self.save_start(target_dir)
+
+     def on_validation_batch_end(
+         self,
+         trainer: "pl.Trainer",
+         pl_module: "pl.LightningModule",
+         outputs: STEP_OUTPUT,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         return self.process_batch(trainer, pl_module, outputs, batch)
+
+     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         self.save_end()
+
+     def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
+         self.save_start(target_dir)
+
+     def on_predict_batch_end(
+         self,
+         trainer: "pl.Trainer",
+         pl_module: "pl.LightningModule",
+         samples: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         return self.process_batch(trainer, pl_module, samples, batch)
+
+     def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         self.save_end()
src/callbacks/simple_ema.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Any, Dict
+
+ import torch
+ import torch.nn as nn
+ import threading
+ import lightning.pytorch as pl
+ from lightning.pytorch import Callback
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from src.utils.copy import swap_tensors
+
+ class SimpleEMA(Callback):
+     def __init__(self, net:nn.Module, ema_net:nn.Module,
+                  decay: float = 0.9999,
+                  every_n_steps: int = 1,
+                  eval_original_model:bool = False
+     ):
+         super().__init__()
+         self.decay = decay
+         self.every_n_steps = every_n_steps
+         self.eval_original_model = eval_original_model
+         self._stream = torch.cuda.Stream()
+
+         self.net_params = list(net.parameters())
+         self.ema_params = list(ema_net.parameters())
+
+     def swap_model(self):
+         for ema_p, p, in zip(self.ema_params, self.net_params):
+             swap_tensors(ema_p, p)
+
+     def ema_step(self):
+         @torch.no_grad()
+         def ema_update(ema_model_tuple, current_model_tuple, decay):
+             torch._foreach_mul_(ema_model_tuple, decay)
+             torch._foreach_add_(
+                 ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
+             )
+
+         if self._stream is not None:
+             self._stream.wait_stream(torch.cuda.current_stream())
+         with torch.cuda.stream(self._stream):
+             ema_update(self.ema_params, self.net_params, self.decay)
+
+
+     def on_train_batch_end(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+     ) -> None:
+         if trainer.global_step % self.every_n_steps == 0:
+             self.ema_step()
+
+     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         if not self.eval_original_model:
+             self.swap_model()
+
+     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         if not self.eval_original_model:
+             self.swap_model()
+
+     def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         if not self.eval_original_model:
+             self.swap_model()
+
+     def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         if not self.eval_original_model:
+             self.swap_model()
+
+
+     def state_dict(self) -> Dict[str, Any]:
+         return {
+             "decay": self.decay,
+             "every_n_steps": self.every_n_steps,
+             "eval_original_model": self.eval_original_model,
+         }
+
+     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+         self.decay = state_dict["decay"]
+         self.every_n_steps = state_dict["every_n_steps"]
+         self.eval_original_model = state_dict["eval_original_model"]
+
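
Editor's note: a minimal sketch of the update rule SimpleEMA applies on each step (decay 0.9999 as in the configs; toy tensors for illustration). Each EMA parameter moves toward the live parameter as ema <- decay * ema + (1 - decay) * param, done in-place across the whole parameter list with the foreach ops used above.

    import torch

    decay = 0.9999
    ema_params = [torch.zeros(3)]   # stands in for ema_net.parameters()
    net_params = [torch.ones(3)]    # stands in for net.parameters()
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, net_params, alpha=1.0 - decay)
    # ema_params[0] is now 0.0001 everywhere
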
src/data/__init__.py ADDED
@@ -0,0 +1 @@
+
src/data/dataset/__init__.py ADDED
File without changes
src/data/dataset/celeba.py ADDED
@@ -0,0 +1,11 @@
+ from typing import Callable
+ from torchvision.datasets import CelebA
+
+
+ class LocalDataset(CelebA):
+     def __init__(self, root:str, ):
+         super(LocalDataset, self).__init__(root, "train")
+
+     def __getitem__(self, idx):
+         data = super().__getitem__(idx)
+         return data
src/data/dataset/imagenet.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ from PIL import Image
+ from torchvision.datasets import ImageFolder
+ from torchvision.transforms.functional import to_tensor
+ from torchvision.transforms import Normalize
+
+ from src.data.dataset.metric_dataset import CenterCrop
+
+ class LocalCachedDataset(ImageFolder):
+     def __init__(self, root, resolution=256):
+         super().__init__(root)
+         self.transform = CenterCrop(resolution)
+         self.cache_root = None
+
+     def load_latent(self, latent_path):
+         # Draw a fresh latent from the cached VAE posterior (mean/logvar) each time.
+         pk_data = torch.load(latent_path)
+         mean = pk_data['mean'].to(torch.float32)
+         logvar = pk_data['logvar'].to(torch.float32)
+         logvar = torch.clamp(logvar, -30.0, 20.0)
+         std = torch.exp(0.5 * logvar)
+         latent = mean + torch.randn_like(mean) * std
+         return latent
+
+     def __getitem__(self, idx: int):
+         image_path, target = self.samples[idx]
+         latent_path = image_path.replace(self.root, self.cache_root) + ".pt"
+
+         raw_image = Image.open(image_path).convert('RGB')
+         raw_image = self.transform(raw_image)
+         raw_image = to_tensor(raw_image)
+         if self.cache_root is not None:
+             latent = self.load_latent(latent_path)
+         else:
+             latent = raw_image
+         return raw_image, latent, target
+
+ class ImageNet256(LocalCachedDataset):
+     def __init__(self, root, ):
+         super().__init__(root, 256)
+         self.cache_root = root + "_256_latent"
+
+ class ImageNet512(LocalCachedDataset):
+     def __init__(self, root, ):
+         super().__init__(root, 512)
+         self.cache_root = root + "_512_latent"
+
+ class PixImageNet(ImageFolder):
+     def __init__(self, root, resolution=256):
+         super().__init__(root)
+         self.transform = CenterCrop(resolution)
+         self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     def __getitem__(self, idx: int):
+         image_path, target = self.samples[idx]
+         raw_image = Image.open(image_path).convert('RGB')
+         raw_image = self.transform(raw_image)
+         raw_image = to_tensor(raw_image)
+
+         normalized_image = self.normalize(raw_image)
+         return raw_image, normalized_image, target
+
+ class PixImageNet64(PixImageNet):
+     def __init__(self, root, ):
+         super().__init__(root, 64)
+
+ class PixImageNet128(PixImageNet):
+     def __init__(self, root, ):
+         super().__init__(root, 128)
+
+
+ class PixImageNet256(PixImageNet):
+     def __init__(self, root, ):
+         super().__init__(root, 256)
+
+ class PixImageNet512(PixImageNet):
+     def __init__(self, root, ):
+         super().__init__(root, 512)
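
Editor's note: the cached-latent path above resamples a VAE latent per access via the usual reparameterization. A tiny sketch of the same computation with made-up shapes and values (names mirror load_latent):

    import torch

    mean = torch.zeros(4, 32, 32)           # pk_data['mean']
    logvar = torch.full((4, 32, 32), -1.0)  # pk_data['logvar'], clamped to [-30, 20]
    std = torch.exp(0.5 * logvar)
    latent = mean + torch.randn_like(mean) * std  # a new draw every epoch
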
src/data/dataset/metric_dataset.py ADDED
@@ -0,0 +1,82 @@
+ import pathlib
+
+ import torch
+ import random
+ import numpy as np
+ from torchvision.io.image import read_image
+ import torchvision.transforms as tvtf
+ from torch.utils.data import Dataset
+
+ class CenterCrop:
+     def __init__(self, size):
+         self.size = size
+     def __call__(self, image):
+         def center_crop_arr(pil_image, image_size):
+             """
+             Center cropping implementation from ADM.
+             https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+             """
+             while min(*pil_image.size) >= 2 * image_size:
+                 pil_image = pil_image.resize(
+                     tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+                 )
+
+             scale = image_size / min(*pil_image.size)
+             pil_image = pil_image.resize(
+                 tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+             )
+
+             arr = np.array(pil_image)
+             crop_y = (arr.shape[0] - image_size) // 2
+             crop_x = (arr.shape[1] - image_size) // 2
+             return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+         return center_crop_arr(image, self.size)
+
+
+ from PIL import Image
+ IMG_EXTENSIONS = (
+     "*.png",
+     "*.JPEG",
+     "*.jpeg",
+     "*.jpg"
+ )
+
+ def test_collate(batch):
+     return torch.stack(batch)
+
+ class ImageDataset(Dataset):
+     def __init__(self, root, image_size=(224, 224)):
+         self.root = pathlib.Path(root)
+         images = []
+         for ext in IMG_EXTENSIONS:
+             images.extend(self.root.rglob(ext))
+         random.shuffle(images)
+         self.images = list(map(lambda x: str(x), images))
+         self.transform = tvtf.Compose(
+             [
+                 CenterCrop(image_size[0]),
+                 tvtf.ToTensor(),
+                 tvtf.Lambda(lambda x: (x*255).to(torch.uint8)),
+                 tvtf.Lambda(lambda x: x.expand(3, -1, -1))
+             ]
+         )
+         self.size = image_size
+
+     def __getitem__(self, idx):
+         try:
+             image = Image.open(self.images[idx])
+             image = self.transform(image)
+         except Exception as e:
+             print(self.images[idx])
+             image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)
+
+         # print(image)
+         metadata = dict(
+             path = self.images[idx],
+             root = self.root,
+         )
+         return image #, metadata
+
+     def __len__(self):
+         return len(self.images)
src/data/dataset/randn.py ADDED
@@ -0,0 +1,41 @@
+ import os.path
+ import random
+
+ import torch
+ from torch.utils.data import Dataset
+
+
+
+ class RandomNDataset(Dataset):
+     def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, ):
+         self.selected_classes = selected_classes
+         if selected_classes is not None:
+             num_classes = len(selected_classes)
+             max_num_instances = 10*num_classes
+         self.num_classes = num_classes
+         self.seeds = seeds
+         if seeds is not None:
+             self.max_num_instances = len(seeds)*num_classes
+             self.num_seeds = len(seeds)
+         else:
+             self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
+             self.max_num_instances = self.num_seeds*num_classes
+
+         self.latent_shape = latent_shape
+
+
+     def __getitem__(self, idx):
+         label = idx // self.num_seeds
+         if self.selected_classes:
+             label = self.selected_classes[label]
+         seed = random.randint(0, 1<<31) #idx % self.num_seeds
+         if self.seeds is not None:
+             seed = self.seeds[idx % self.num_seeds]
+
+         # cls_dir = os.path.join(self.root, f"{label}")
+         filename = f"{label}_{seed}.png",
+         generator = torch.Generator().manual_seed(seed)
+         latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
+         return latent, label, filename
+     def __len__(self):
+         return self.max_num_instances
src/data/var_training.py ADDED
@@ -0,0 +1,145 @@
+ import torch
+ from typing import Callable
+ from src.diffusion.base.training import *
+ from src.diffusion.base.scheduling import BaseScheduler
+ import concurrent.futures
+ from concurrent.futures import ProcessPoolExecutor
+ from typing import List
+ from PIL import Image
+ import torch
+ import random
+ import numpy as np
+ import copy
+ import torchvision.transforms.functional as tvtf
+ from src.models.vae import uint82fp
+
+
+ def center_crop_arr(pil_image, width, height):
+     """
+     Center cropping implementation from ADM.
+     https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+     """
+     while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
+         pil_image = pil_image.resize(
+             tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+         )
+
+     scale = max(width / pil_image.size[0], height / pil_image.size[1])
+     pil_image = pil_image.resize(
+         tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+     )
+     arr = np.array(pil_image)
+     crop_y = random.randint(0, (arr.shape[0] - height))
+     crop_x = random.randint(0, (arr.shape[1] - width))
+     return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])
+
+ def process_fn(width, height, data, hflip=0.5):
+     image, label = data
+     if random.uniform(0, 1) > hflip: # hflip
+         image = tvtf.hflip(image)
+     image = center_crop_arr(image, width, height) # crop
+     image = np.array(image).transpose(2, 0, 1)
+     return image, label
+
+ class VARCandidate:
+     def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
+         self.aspect_ratio = aspect_ratio
+         self.width = int(width)
+         self.height = int(height)
+         self.buffer = buffer
+         self.max_buffer_size = max_buffer_size
+
+     def add_sample(self, data):
+         self.buffer.append(data)
+         self.buffer = self.buffer[-self.max_buffer_size:]
+
+     def ready(self, batch_size):
+         return len(self.buffer) >= batch_size
+
+     def get_batch(self, batch_size):
+         batch = self.buffer[:batch_size]
+         self.buffer = self.buffer[batch_size:]
+         batch = [copy.deepcopy(b.result()) for b in batch]
+         x, y = zip(*batch)
+         x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
+         x = list(map(uint82fp, x))
+         return x, y
+
+ class VARTransformEngine:
+     def __init__(self,
+                  base_image_size,
+                  num_aspect_ratios,
+                  min_aspect_ratio,
+                  max_aspect_ratio,
+                  num_workers = 8,
+     ):
+         self.base_image_size = base_image_size
+         self.num_aspect_ratios = num_aspect_ratios
+         self.min_aspect_ratio = min_aspect_ratio
+         self.max_aspect_ratio = max_aspect_ratio
+         self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
+         self.aspect_ratios = self.aspect_ratios.tolist()
+         self.candidates_pool = []
+         for i in range(self.num_aspect_ratios):
+             candidate = VARCandidate(
+                 aspect_ratio=self.aspect_ratios[i],
+                 width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
+                 height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
+                 buffer=[],
+                 max_buffer_size=1024
+             )
+             self.candidates_pool.append(candidate)
+         self.default_candidate = VARCandidate(
+             aspect_ratio=1.0,
+             width=self.base_image_size,
+             height=self.base_image_size,
+             buffer=[],
+             max_buffer_size=1024,
+         )
+         self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
+         self._prefill_count = 100
+
+     def find_candidate(self, data):
+         image = data[0]
+         aspect_ratio = image.size[0] / image.size[1]
+         min_distance = 1000000
+         min_candidate = None
+         for candidate in self.candidates_pool:
+             dis = abs(aspect_ratio - candidate.aspect_ratio)
+             if dis < min_distance:
+                 min_distance = dis
+                 min_candidate = candidate
+         return min_candidate
+
+
+     def __call__(self, batch_data):
+         self._prefill_count -= 1
+         if isinstance(batch_data[0], torch.Tensor):
+             batch_data[0] = batch_data[0].unbind(0)
+
+         batch_data = list(zip(*batch_data))
+         for data in batch_data:
+             candidate = self.find_candidate(data)
+             future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
+             candidate.add_sample(future)
+             if self._prefill_count >= 0:
+                 future = self.executor_pool.submit(process_fn,
+                                                    self.default_candidate.width,
+                                                    self.default_candidate.height,
+                                                    data)
+                 self.default_candidate.add_sample(future)
+
+         batch_size = len(batch_data)
+         random.shuffle(self.candidates_pool)
+         for candidate in self.candidates_pool:
+             if candidate.ready(batch_size=batch_size):
+                 return candidate.get_batch(batch_size=batch_size)
+
+         # fallback to default 256
+         for data in batch_data:
+             future = self.executor_pool.submit(process_fn,
+                                                self.default_candidate.width,
+                                                self.default_candidate.height,
+                                                data)
+             self.default_candidate.add_sample(future)
+         return self.default_candidate.get_batch(batch_size=batch_size)
src/diffusion/__init__.py ADDED
File without changes
src/diffusion/base/guidance.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+
+ def simple_guidance_fn(out, cfg):
+     uncondition, condition = out.chunk(2, dim=0)
+     out = uncondition + cfg * (condition - uncondition)
+     return out
+
+ def c3_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
+     return out
+
+ def c4_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+     out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
+     return out
+
+ def c4_p05_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+     out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
+     return out
+
+ def c4_p10_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+     out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
+     return out
+
+ def c4_p15_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+     out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
+     return out
+
+ def c4_p20_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+     out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
+     return out
+
+ def p4_guidance_fn(out, cfg):
+     # guidance function in DiT/SiT, seems like a bug not a feature?
+     uncondition, condition = out.chunk(2, dim=0)
+     out = condition
+     out[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
+     return out
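
Editor's note: a hedged sketch of how simple_guidance_fn is wired up at sampling time. The samplers concatenate the batch as [unconditional, conditional], and the guided output is uncond + cfg * (cond - uncond); the toy tensors below are illustrative placeholders, not part of the commit.

    import torch
    from src.diffusion.base.guidance import simple_guidance_fn

    uncond = torch.zeros(2, 4, 32, 32)
    cond = torch.ones(2, 4, 32, 32)
    out = torch.cat([uncond, cond], dim=0)   # order matches the samplers in this repo
    guided = simple_guidance_fn(out, cfg=3.0)  # equals uncond + 3.0 * (cond - uncond)
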
src/diffusion/base/sampling.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Union, List
+
+ import torch
+ import torch.nn as nn
+ from typing import Callable
+ from src.diffusion.base.scheduling import BaseScheduler
+
+ class BaseSampler(nn.Module):
+     def __init__(self,
+                  scheduler: BaseScheduler = None,
+                  guidance_fn: Callable = None,
+                  num_steps: int = 250,
+                  guidance: Union[float, List[float]] = 1.0,
+                  *args,
+                  **kwargs
+     ):
+         super(BaseSampler, self).__init__()
+         self.num_steps = num_steps
+         self.guidance = guidance
+         self.guidance_fn = guidance_fn
+         self.scheduler = scheduler
+
+
+     def _impl_sampling(self, net, noise, condition, uncondition):
+         raise NotImplementedError
+
+     def __call__(self, net, noise, condition, uncondition):
+         denoised = self._impl_sampling(net, noise, condition, uncondition)
+         return denoised
+
src/diffusion/base/scheduling.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from torch import Tensor
+
+ class BaseScheduler:
+     def alpha(self, t) -> Tensor:
+         ...
+     def sigma(self, t) -> Tensor:
+         ...
+
+     def dalpha(self, t) -> Tensor:
+         ...
+     def dsigma(self, t) -> Tensor:
+         ...
+
+     def dalpha_over_alpha(self, t) -> Tensor:
+         return self.dalpha(t) / self.alpha(t)
+
+     def dsigma_mul_sigma(self, t) -> Tensor:
+         return self.dsigma(t)*self.sigma(t)
+
+     def drift_coefficient(self, t):
+         alpha, sigma = self.alpha(t), self.sigma(t)
+         dalpha, dsigma = self.dalpha(t), self.dsigma(t)
+         return dalpha/(alpha + 1e-6)
+
+     def diffuse_coefficient(self, t):
+         alpha, sigma = self.alpha(t), self.sigma(t)
+         dalpha, dsigma = self.dalpha(t), self.dsigma(t)
+         return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2
+
+     def w(self, t):
+         return self.sigma(t)
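
Editor's note: as a reading aid (my reading of the hooks above, not part of the commit), the scheduler parameterizes the usual interpolation between data and noise, and the two coefficients match the code with a small epsilon for numerical stability:

    x_t = \alpha(t)\, x_0 + \sigma(t)\, \varepsilon,\qquad
    \mathrm{drift}(t) = \frac{\dot{\alpha}(t)}{\alpha(t)+\epsilon},\qquad
    \mathrm{diffuse}(t) = \dot{\sigma}(t)\,\sigma(t) - \frac{\dot{\alpha}(t)}{\alpha(t)+\epsilon}\,\sigma(t)^{2}.
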
src/diffusion/base/training.py ADDED
@@ -0,0 +1,29 @@
+ import time
+
+ import torch
+ import torch.nn as nn
+
+ class BaseTrainer(nn.Module):
+     def __init__(self,
+                  null_condition_p=0.1,
+                  log_var=False,
+     ):
+         super(BaseTrainer, self).__init__()
+         self.null_condition_p = null_condition_p
+         self.log_var = log_var
+
+     def preprocess(self, raw_images, x, condition, uncondition):
+         # Randomly replace conditions with the null condition for classifier-free guidance training.
+         bsz = x.shape[0]
+         if self.null_condition_p > 0:
+             mask = torch.rand((bsz), device=condition.device) < self.null_condition_p
+             mask = mask.expand_as(condition)
+             condition[mask] = uncondition[mask]
+         return raw_images, x, condition
+
+     def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+         raise NotImplementedError
+
+     def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
+         raw_images, x, condition = self.preprocess(raw_images, x, condition, uncondition)
+         return self._impl_trainstep(net, ema_net, raw_images, x, condition)
+
src/diffusion/ddpm/ddim_sampling.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ from src.diffusion.base.scheduling import *
+ from src.diffusion.base.sampling import *
+
+ from typing import Callable
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+ class DDIMSampler(BaseSampler):
+     def __init__(
+             self,
+             train_num_steps=1000,
+             *args,
+             **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.train_num_steps = train_num_steps
+         assert self.scheduler is not None
+
+     def _impl_sampling(self, net, noise, condition, uncondition):
+         batch_size = noise.shape[0]
+         steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device)
+         steps = torch.flip(steps, dims=[0])
+         cfg_condition = torch.cat([uncondition, condition], dim=0)
+         x = x0 = noise
+         for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+             t_cur = t_cur.repeat(batch_size)
+             t_next = t_next.repeat(batch_size)
+             sigma = self.scheduler.sigma(t_cur)
+             alpha = self.scheduler.alpha(t_cur)
+             sigma_next = self.scheduler.sigma(t_next)
+             alpha_next = self.scheduler.alpha(t_next)
+             cfg_x = torch.cat([x, x], dim=0)
+             t = t_cur.repeat(2)
+             out = net(cfg_x, t, cfg_condition)
+             out = self.guidance_fn(out, self.guidance)
+             x0 = (x - sigma * out) / alpha
+             x = alpha_next * x0 + sigma_next * out
+         return x0
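
Editor's note: for reference (not part of the commit), each loop iteration above performs the deterministic DDIM update, with the guided network output playing the role of the noise estimate:

    \hat{x}_0 = \frac{x_t - \sigma_t\, \hat{\varepsilon}}{\alpha_t},\qquad
    x_{t_{\mathrm{next}}} = \alpha_{t_{\mathrm{next}}}\, \hat{x}_0 + \sigma_{t_{\mathrm{next}}}\, \hat{\varepsilon}.
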
src/diffusion/ddpm/scheduling.py ADDED
@@ -0,0 +1,102 @@
+ import math
+ import torch
+ from src.diffusion.base.scheduling import *
+
+
+ class DDPMScheduler(BaseScheduler):
+     def __init__(
+             self,
+             beta_min=0.0001,
+             beta_max=0.02,
+             num_steps=1000,
+     ):
+         super().__init__()
+         self.beta_min = beta_min
+         self.beta_max = beta_max
+         self.num_steps = num_steps
+
+         self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
+         self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
+         self.sigmas_table = 1-self.alphas_table
+
+
+     def beta(self, t) -> Tensor:
+         t = t.to(torch.long)
+         return self.betas_table[t].view(-1, 1, 1, 1)
+
+     def alpha(self, t) -> Tensor:
+         t = t.to(torch.long)
+         return self.alphas_table[t].view(-1, 1, 1, 1)**0.5
+
+     def sigma(self, t) -> Tensor:
+         t = t.to(torch.long)
+         return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5
+
+     def dsigma(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dalpha_over_alpha(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dsigma_mul_sigma(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dalpha(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def drift_coefficient(self, t):
+         raise NotImplementedError("wrong usage")
+
+     def diffuse_coefficient(self, t):
+         raise NotImplementedError("wrong usage")
+
+     def w(self, t):
+         raise NotImplementedError("wrong usage")
+
+
+ class VPScheduler(BaseScheduler):
+     def __init__(
+             self,
+             beta_min=0.1,
+             beta_max=20,
+     ):
+         super().__init__()
+         self.beta_min = beta_min
+         self.beta_d = beta_max - beta_min
+     def beta(self, t) -> Tensor:
+         t = torch.clamp(t, min=1e-3, max=1)
+         return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)
+
+     def sigma(self, t) -> Tensor:
+         t = torch.clamp(t, min=1e-3, max=1)
+         inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t
+         return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)
+
+     def dsigma(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dalpha_over_alpha(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dsigma_mul_sigma(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def dalpha(self, t) -> Tensor:
+         raise NotImplementedError("wrong usage")
+
+     def alpha(self, t) -> Tensor:
+         t = torch.clamp(t, min=1e-3, max=1)
+         inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
+         return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)
+
+     def drift_coefficient(self, t):
+         raise NotImplementedError("wrong usage")
+
+     def diffuse_coefficient(self, t):
+         raise NotImplementedError("wrong usage")
+
+     def w(self, t):
+         return self.diffuse_coefficient(t)
+
+
src/diffusion/ddpm/training.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ from typing import Callable
3
+ from src.diffusion.base.training import *
4
+ from src.diffusion.base.scheduling import BaseScheduler
5
+
6
+ def inverse_sigma(alpha, sigma):
7
+ return 1/sigma**2
8
+ def snr(alpha, sigma):
9
+ return alpha/sigma
10
+ def minsnr(alpha, sigma, threshold=5):
11
+ return torch.clip(alpha/sigma, min=threshold)
12
+ def maxsnr(alpha, sigma, threshold=5):
13
+ return torch.clip(alpha/sigma, max=threshold)
14
+ def constant(alpha, sigma):
15
+ return 1
16
+
17
+ class VPTrainer(BaseTrainer):
18
+ def __init__(
19
+ self,
20
+ scheduler: BaseScheduler,
21
+ loss_weight_fn:Callable=constant,
22
+ train_max_t=1000,
23
+ lognorm_t=False,
24
+ *args,
25
+ **kwargs
26
+ ):
27
+ super().__init__(*args, **kwargs)
28
+ self.lognorm_t = lognorm_t
29
+ self.scheduler = scheduler
30
+ self.loss_weight_fn = loss_weight_fn
31
+ self.train_max_t = train_max_t
32
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
33
+ batch_size = x.shape[0]
34
+ if self.lognorm_t:
35
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
36
+ else:
37
+ t = torch.rand(batch_size).to(x.device, x.dtype)
38
+
39
+ noise = torch.randn_like(x)
40
+ alpha = self.scheduler.alpha(t)
41
+ sigma = self.scheduler.sigma(t)
42
+ x_t = alpha * x + noise * sigma
43
+ out = net(x_t, t*self.train_max_t, y)
44
+ weight = self.loss_weight_fn(alpha, sigma)
45
+ loss = weight*(out - noise)**2
46
+
47
+ out = dict(
48
+ loss=loss.mean(),
49
+ )
50
+ return out
51
+
52
+
53
+ class DDPMTrainer(BaseTrainer):
54
+ def __init__(
55
+ self,
56
+ scheduler: BaseScheduler,
57
+ loss_weight_fn: Callable = constant,
58
+ train_max_t=1000,
59
+ lognorm_t=False,
60
+ *args,
61
+ **kwargs
62
+ ):
63
+ super().__init__(*args, **kwargs)
64
+ self.lognorm_t = lognorm_t
65
+ self.scheduler = scheduler
66
+ self.loss_weight_fn = loss_weight_fn
67
+ self.train_max_t = train_max_t
68
+
69
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
70
+ batch_size = x.shape[0]
71
+ t = torch.randint(0, self.train_max_t, (batch_size,), device=x.device)
72
+ noise = torch.randn_like(x)
73
+ alpha = self.scheduler.alpha(t)
74
+ sigma = self.scheduler.sigma(t)
75
+ x_t = alpha * x + noise * sigma
76
+ out = net(x_t, t, y)
77
+ weight = self.loss_weight_fn(alpha, sigma)
78
+ loss = weight * (out - noise) ** 2
79
+
80
+ out = dict(
81
+ loss=loss.mean(),
82
+ )
83
+ return out
src/diffusion/ddpm/vp_sampling.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+
3
+ from src.diffusion.base.scheduling import *
4
+ from src.diffusion.base.sampling import *
5
+ from typing import Callable
6
+
7
+ def ode_step_fn(x, eps, beta, sigma, dt):
8
+ return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt
9
+
10
+ def sde_step_fn(x, eps, beta, sigma, dt):
11
+ return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x)
12
+
13
+ import logging
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class VPEulerSampler(BaseSampler):
17
+ def __init__(
18
+ self,
19
+ train_max_t=1000,
20
+ guidance_fn: Callable = None,
21
+ step_fn: Callable = ode_step_fn,
22
+ last_step=None,
23
+ last_step_fn: Callable = ode_step_fn,
24
+ *args,
25
+ **kwargs
26
+ ):
27
+ super().__init__(*args, **kwargs)
28
+ self.guidance_fn = guidance_fn
29
+ self.step_fn = step_fn
30
+ self.last_step = last_step
31
+ self.last_step_fn = last_step_fn
32
+ self.train_max_t = train_max_t
33
+
34
+ if self.last_step is None or self.num_steps == 1:
35
+ self.last_step = 1.0 / self.num_steps
36
+ assert self.last_step > 0.0
37
+ assert self.scheduler is not None
38
+
39
+ def _impl_sampling(self, net, noise, condition, uncondition):
40
+ batch_size = noise.shape[0]
41
+ steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
42
+ steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
43
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
44
+ x = noise
45
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
46
+ dt = t_next - t_cur
47
+ t_cur = t_cur.repeat(batch_size)
48
+ sigma = self.scheduler.sigma(t_cur)
49
+ beta = self.scheduler.beta(t_cur)
50
+ cfg_x = torch.cat([x, x], dim=0)
51
+ cfg_t = t_cur.repeat(2)
52
+ out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition)
53
+ eps = self.guidance_fn(out, self.guidance)
54
+ if i < self.num_steps -1 :
55
+ x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
56
+ x = self.step_fn(x, eps, beta, sigma, dt)
57
+ else:
58
+ x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
59
+ return x
src/diffusion/flow_matching/adam_sampling.py ADDED
@@ -0,0 +1,107 @@
1
+ import math
2
+ from src.diffusion.base.sampling import *
3
+ from src.diffusion.base.scheduling import *
4
+ from src.diffusion.pre_integral import *
5
+
6
+ from typing import Callable, List, Tuple
7
+
8
+ def ode_step_fn(x, v, dt, s, w):
9
+ return x + v * dt
10
+
11
+ def t2snr(t):
12
+ if isinstance(t, torch.Tensor):
13
+ return (t.clip(min=1e-8)/(1-t + 1e-8))
14
+ if isinstance(t, List) or isinstance(t, Tuple):
15
+ return [t2snr(t) for t in t]
16
+ t = max(t, 1e-8)
17
+ return (t/(1-t + 1e-8))
18
+
19
+ def t2logsnr(t):
20
+ if isinstance(t, torch.Tensor):
21
+ return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
22
+ if isinstance(t, List) or isinstance(t, Tuple):
23
+ return [t2logsnr(t) for t in t]
24
+ t = max(t, 1e-3)
25
+ return math.log(t/(1-t + 1e-3))
26
+
27
+ def t2isnr(t):
28
+ return 1/t2snr(t)
29
+
30
+ def nop(t):
31
+ return t
32
+
33
+ def shift_respace_fn(t, shift=3.0):
34
+ return t / (t + (1 - t) * shift)
35
+
36
+ import logging
37
+ logger = logging.getLogger(__name__)
38
+
39
+ class AdamLMSampler(BaseSampler):
40
+ def __init__(
41
+ self,
42
+ order: int = 2,
43
+ timeshift: float = 1.0,
44
+ lms_transform_fn: Callable = nop,
45
+ w_scheduler: BaseScheduler = None,
46
+ step_fn: Callable = ode_step_fn,
47
+ *args,
48
+ **kwargs
49
+ ):
50
+ super().__init__(*args, **kwargs)
51
+ self.step_fn = step_fn
52
+ self.w_scheduler = w_scheduler
53
+
54
+ assert self.scheduler is not None
55
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
56
+ self.order = order
57
+ self.lms_transform_fn = lms_transform_fn
58
+
59
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
60
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
61
+ self.timesteps = shift_respace_fn(timesteps, timeshift)
62
+ self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
63
+ self._reparameterize_coeffs()
64
+
65
+ def _reparameterize_coeffs(self):
66
+ solver_coeffs = [[] for _ in range(self.num_steps)]
67
+ for i in range(0, self.num_steps):
68
+ pre_vs = [1.0, ]*(i+1)
69
+ pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
70
+ int_t_start = self.lms_transform_fn(self.timesteps[i])
71
+ int_t_end = self.lms_transform_fn(self.timesteps[i+1])
72
+
73
+ order_annealing = self.order #self.num_steps - i
74
+ order = min(self.order, i + 1, order_annealing)
75
+
76
+ _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
77
+ solver_coeffs[i] = coeffs
78
+ self.solver_coeffs = solver_coeffs
79
+
80
+ def _impl_sampling(self, net, noise, condition, uncondition):
81
+ """
82
+ sampling process of Euler sampler
83
+ -
84
+ """
85
+ batch_size = noise.shape[0]
86
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
87
+ x = x0 = noise
88
+ pred_trajectory = []
89
+ t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
90
+ timedeltas = self.timedeltas
91
+ solver_coeffs = self.solver_coeffs
92
+ for i in range(self.num_steps):
93
+ cfg_x = torch.cat([x, x], dim=0)
94
+ cfg_t = t_cur.repeat(2)
95
+ out = net(cfg_x, cfg_t, cfg_condition)
96
+ out = self.guidance_fn(out, self.guidances[i])
97
+ pred_trajectory.append(out)
98
+ out = torch.zeros_like(out)
99
+ order = len(self.solver_coeffs[i])
100
+ for j in range(order):
101
+ out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
102
+ v = out
103
+ dt = timedeltas[i]
104
+ x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
105
+ x = self.step_fn(x, v, dt, s=0, w=0)
106
+ t_cur += dt
107
+ return x
src/diffusion/flow_matching/sampling.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+
3
+ from src.diffusion.base.guidance import *
4
+ from src.diffusion.base.scheduling import *
5
+ from src.diffusion.base.sampling import *
6
+
7
+ from typing import Callable
8
+
9
+
10
+ def shift_respace_fn(t, shift=3.0):
11
+ return t / (t + (1 - t) * shift)
12
+
13
+ def ode_step_fn(x, v, dt, s, w):
14
+ return x + v * dt
15
+
16
+ def sde_mean_step_fn(x, v, dt, s, w):
17
+ return x + v * dt + s * w * dt
18
+
19
+ def sde_step_fn(x, v, dt, s, w):
20
+ return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
21
+
22
+ def sde_preserve_step_fn(x, v, dt, s, w):
23
+ return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
24
+
25
+
26
+ import logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+ class EulerSampler(BaseSampler):
30
+ def __init__(
31
+ self,
32
+ w_scheduler: BaseScheduler = None,
33
+ timeshift=1.0,
34
+ step_fn: Callable = ode_step_fn,
35
+ last_step=None,
36
+ last_step_fn: Callable = ode_step_fn,
37
+ *args,
38
+ **kwargs
39
+ ):
40
+ super().__init__(*args, **kwargs)
41
+ self.step_fn = step_fn
42
+ self.last_step = last_step
43
+ self.last_step_fn = last_step_fn
44
+ self.w_scheduler = w_scheduler
45
+ self.timeshift = timeshift
46
+
47
+ if self.last_step is None or self.num_steps == 1:
48
+ self.last_step = 1.0 / self.num_steps
49
+
50
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
51
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
52
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
53
+
54
+ assert self.last_step > 0.0
55
+ assert self.scheduler is not None
56
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
57
+ if self.w_scheduler is not None:
58
+ if self.step_fn == ode_step_fn:
59
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
60
+
61
+ def _impl_sampling(self, net, noise, condition, uncondition):
62
+ """
63
+ sampling process of Euler sampler
64
+ -
65
+ """
66
+ batch_size = noise.shape[0]
67
+ steps = self.timesteps.to(noise.device)
68
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
69
+ x = noise
70
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
71
+ dt = t_next - t_cur
72
+ t_cur = t_cur.repeat(batch_size)
73
+ sigma = self.scheduler.sigma(t_cur)
74
+ dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
75
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
76
+ if self.w_scheduler:
77
+ w = self.w_scheduler.w(t_cur)
78
+ else:
79
+ w = 0.0
80
+
81
+ cfg_x = torch.cat([x, x], dim=0)
82
+ cfg_t = t_cur.repeat(2)
83
+ out = net(cfg_x, cfg_t, cfg_condition)
84
+ out = self.guidance_fn(out, self.guidance)
85
+ v = out
86
+ s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
87
+ if i < self.num_steps -1 :
88
+ x = self.step_fn(x, v, dt, s=s, w=w)
89
+ else:
90
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
91
+ return x
92
+
93
+
94
+ class HeunSampler(BaseSampler):
95
+ def __init__(
96
+ self,
97
+ scheduler: BaseScheduler = None,
98
+ w_scheduler: BaseScheduler = None,
99
+ exact_henu=False,
100
+ timeshift=1.0,
101
+ step_fn: Callable = ode_step_fn,
102
+ last_step=None,
103
+ last_step_fn: Callable = ode_step_fn,
104
+ *args,
105
+ **kwargs
106
+ ):
107
+ super().__init__(*args, **kwargs)
108
+ self.scheduler = scheduler
109
+ self.exact_henu = exact_henu
110
+ self.step_fn = step_fn
111
+ self.last_step = last_step
112
+ self.last_step_fn = last_step_fn
113
+ self.w_scheduler = w_scheduler
114
+ self.timeshift = timeshift
115
+
116
+ if self.last_step is None or self.num_steps == 1:
117
+ self.last_step = 1.0 / self.num_steps
118
+
119
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
120
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
121
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
122
+ assert self.last_step > 0.0
123
+ assert self.scheduler is not None
124
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
125
+ if self.w_scheduler is not None:
126
+ if self.step_fn == ode_step_fn:
127
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
128
+
129
+ def _impl_sampling(self, net, noise, condition, uncondition):
130
+ """
131
+ sampling process of Heun sampler
132
+ -
133
+ """
134
+ batch_size = noise.shape[0]
135
+ steps = self.timesteps.to(noise.device)
136
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
137
+ x = noise
138
+ v_hat, s_hat = 0.0, 0.0
139
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
140
+ dt = t_next - t_cur
141
+ t_cur = t_cur.repeat(batch_size)
142
+ sigma = self.scheduler.sigma(t_cur)
143
+ alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur)
144
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
145
+ t_hat = t_next
146
+ t_hat = t_hat.repeat(batch_size)
147
+ sigma_hat = self.scheduler.sigma(t_hat)
148
+ alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
149
+ dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)
150
+
151
+ if self.w_scheduler:
152
+ w = self.w_scheduler.w(t_cur)
153
+ else:
154
+ w = 0.0
155
+ if i == 0 or self.exact_henu:
156
+ cfg_x = torch.cat([x, x], dim=0)
157
+ cfg_t_cur = t_cur.repeat(2)
158
+ out = net(cfg_x, cfg_t_cur, cfg_condition)
159
+ out = self.guidance_fn(out, self.guidance)
160
+ v = out
161
+ s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma)
162
+ else:
163
+ v = v_hat
164
+ s = s_hat
165
+ x_hat = self.step_fn(x, v, dt, s=s, w=w)
166
+ # Heun correction
167
+ if i < self.num_steps -1:
168
+ cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
169
+ cfg_t_hat = t_hat.repeat(2)
170
+ out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
171
+ out = self.guidance_fn(out, self.guidance)
172
+ v_hat = out
173
+ s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat)
174
+ v = (v + v_hat) / 2
175
+ s = (s + s_hat) / 2
176
+ x = self.step_fn(x, v, dt, s=s, w=w)
177
+ else:
178
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
179
+ return x
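Both samplers above respace their uniform step grid with shift_respace_fn(t, shift) = t / (t + (1 - t) * shift). A small standalone sketch (plain torch, independent of the sampler classes) showing that shift = 3.0 compresses the grid toward t = 0, i.e. toward the high-noise end in this parameterization where alpha(t) = t:

    import torch

    def shift_respace_fn(t, shift=3.0):   # same map as in the samplers above
        return t / (t + (1 - t) * shift)

    uniform = torch.linspace(0.0, 1.0, 6)
    shifted = shift_respace_fn(uniform, shift=3.0)
    print(uniform.tolist())                          # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    print([round(v, 3) for v in shifted.tolist()])   # approx [0.0, 0.077, 0.182, 0.333, 0.571, 1.0]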
src/diffusion/flow_matching/scheduling.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ import torch
3
+ from src.diffusion.base.scheduling import *
4
+
5
+
6
+ class LinearScheduler(BaseScheduler):
7
+ def alpha(self, t) -> Tensor:
8
+ return (t).view(-1, 1, 1, 1)
9
+ def sigma(self, t) -> Tensor:
10
+ return (1-t).view(-1, 1, 1, 1)
11
+ def dalpha(self, t) -> Tensor:
12
+ return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
13
+ def dsigma(self, t) -> Tensor:
14
+ return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
15
+
16
+ # SoTA for ImageNet!
17
+ class GVPScheduler(BaseScheduler):
18
+ def alpha(self, t) -> Tensor:
19
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
20
+ def sigma(self, t) -> Tensor:
21
+ return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
22
+ def dalpha(self, t) -> Tensor:
23
+ return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
24
+ def dsigma(self, t) -> Tensor:
25
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
26
+ def w(self, t):
27
+ return torch.sin(t)**2
28
+
29
+ class ConstScheduler(BaseScheduler):
30
+ def w(self, t):
31
+ return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
32
+
33
+ from src.diffusion.ddpm.scheduling import VPScheduler
34
+ class VPBetaScheduler(VPScheduler):
35
+ def w(self, t):
36
+ return self.beta(t).view(-1, 1, 1, 1)
37
+
38
+
39
+
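With the LinearScheduler above (alpha = t, sigma = 1 - t, dalpha = 1, dsigma = -1), the interpolant and velocity target built by the flow-matching trainers below reduce to x_t = t*x + (1 - t)*noise and v_t = x - noise, and following v_t from x_t up to t = 1 recovers the clean sample. A minimal standalone check in plain torch (same formulas, nothing repo-specific):

    import torch

    x = torch.randn(2, 4, 8, 8)              # clean latent
    eps = torch.randn_like(x)                # noise
    t = torch.rand(2).view(-1, 1, 1, 1)

    x_t = t * x + (1 - t) * eps              # alpha(t) * x + sigma(t) * eps
    v_t = x - eps                            # dalpha * x + dsigma * eps

    print(torch.allclose(x_t + (1 - t) * v_t, x))   # True: one exact step to t = 1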
src/diffusion/flow_matching/training.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ from typing import Callable
3
+ from src.diffusion.base.training import *
4
+ from src.diffusion.base.scheduling import BaseScheduler
5
+
6
+ def inverse_sigma(alpha, sigma):
7
+ return 1/sigma**2
8
+ def snr(alpha, sigma):
9
+ return alpha/sigma
10
+ def minsnr(alpha, sigma, threshold=5):
11
+ return torch.clip(alpha/sigma, min=threshold)
12
+ def maxsnr(alpha, sigma, threshold=5):
13
+ return torch.clip(alpha/sigma, max=threshold)
14
+ def constant(alpha, sigma):
15
+ return 1
16
+
17
+ class FlowMatchingTrainer(BaseTrainer):
18
+ def __init__(
19
+ self,
20
+ scheduler: BaseScheduler,
21
+ loss_weight_fn:Callable=constant,
22
+ lognorm_t=False,
23
+ *args,
24
+ **kwargs
25
+ ):
26
+ super().__init__(*args, **kwargs)
27
+ self.lognorm_t = lognorm_t
28
+ self.scheduler = scheduler
29
+ self.loss_weight_fn = loss_weight_fn
30
+
31
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
32
+ batch_size = x.shape[0]
33
+ if self.lognorm_t:
34
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
35
+ else:
36
+ t = torch.rand(batch_size).to(x.device, x.dtype)
37
+ noise = torch.randn_like(x)
38
+ alpha = self.scheduler.alpha(t)
39
+ dalpha = self.scheduler.dalpha(t)
40
+ sigma = self.scheduler.sigma(t)
41
+ dsigma = self.scheduler.dsigma(t)
42
+ w = self.scheduler.w(t)
43
+
44
+ x_t = alpha * x + noise * sigma
45
+ v_t = dalpha * x + dsigma * noise
46
+ out = net(x_t, t, y)
47
+
48
+ weight = self.loss_weight_fn(alpha, sigma)
49
+
50
+ loss = weight*(out - v_t)**2
51
+
52
+ out = dict(
53
+ loss=loss.mean(),
54
+ )
55
+ return out
src/diffusion/flow_matching/training_cos.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from typing import Callable
3
+ from src.diffusion.base.training import *
4
+ from src.diffusion.base.scheduling import BaseScheduler
5
+
6
+ def inverse_sigma(alpha, sigma):
7
+ return 1/sigma**2
8
+ def snr(alpha, sigma):
9
+ return alpha/sigma
10
+ def minsnr(alpha, sigma, threshold=5):
11
+ return torch.clip(alpha/sigma, min=threshold)
12
+ def maxsnr(alpha, sigma, threshold=5):
13
+ return torch.clip(alpha/sigma, max=threshold)
14
+ def constant(alpha, sigma):
15
+ return 1
16
+
17
+ class COSTrainer(BaseTrainer):
18
+ def __init__(
19
+ self,
20
+ scheduler: BaseScheduler,
21
+ loss_weight_fn:Callable=constant,
22
+ lognorm_t=False,
23
+ *args,
24
+ **kwargs
25
+ ):
26
+ super().__init__(*args, **kwargs)
27
+ self.lognorm_t = lognorm_t
28
+ self.scheduler = scheduler
29
+ self.loss_weight_fn = loss_weight_fn
30
+
31
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
32
+ batch_size = x.shape[0]
33
+ if self.lognorm_t:
34
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
35
+ else:
36
+ t = torch.rand(batch_size).to(x.device, x.dtype)
37
+ noise = torch.randn_like(x)
38
+ alpha = self.scheduler.alpha(t)
39
+ dalpha = self.scheduler.dalpha(t)
40
+ sigma = self.scheduler.sigma(t)
41
+ dsigma = self.scheduler.dsigma(t)
42
+ w = self.scheduler.w(t)
43
+
44
+ x_t = alpha * x + noise * sigma
45
+ v_t = dalpha * x + dsigma * noise
46
+ out = net(x_t, t, y)
47
+
48
+ weight = self.loss_weight_fn(alpha, sigma)
49
+
50
+ fm_loss = weight*(out - v_t)**2
51
+ cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
52
+ cos_loss = 1 - cos_sim
53
+
54
+ out = dict(
55
+ fm_loss=fm_loss.mean(),
56
+ cos_loss=cos_loss.mean(),
57
+ loss=fm_loss.mean() + cos_loss.mean(),
58
+ )
59
+ return out
src/diffusion/flow_matching/training_repa.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import copy
3
+ import timm
4
+ from torch.nn import Parameter
5
+
6
+ from src.utils.no_grad import no_grad
7
+ from typing import Callable, Iterator, Tuple
8
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
+ from torchvision.transforms import Normalize
10
+ from src.diffusion.base.training import *
11
+ from src.diffusion.base.scheduling import BaseScheduler
12
+
13
+ def inverse_sigma(alpha, sigma):
14
+ return 1/sigma**2
15
+ def snr(alpha, sigma):
16
+ return alpha/sigma
17
+ def minsnr(alpha, sigma, threshold=5):
18
+ return torch.clip(alpha/sigma, min=threshold)
19
+ def maxsnr(alpha, sigma, threshold=5):
20
+ return torch.clip(alpha/sigma, max=threshold)
21
+ def constant(alpha, sigma):
22
+ return 1
23
+
24
+
25
+ class DINOv2(nn.Module):
26
+ def __init__(self, weight_path:str):
27
+ super(DINOv2, self).__init__()
28
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
29
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
30
+ self.encoder.head = torch.nn.Identity()
31
+ self.patch_size = self.encoder.patch_embed.patch_size
32
+ self.precomputed_pos_embed = dict()
33
+
34
+ def fetch_pos(self, h, w):
35
+ key = (h, w)
36
+ if key in self.precomputed_pos_embed:
37
+ return self.precomputed_pos_embed[key]
38
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
39
+ self.pos_embed.data, [h, w],
40
+ )
41
+ self.precomputed_pos_embed[key] = value
42
+ return value
43
+
44
+ def forward(self, x):
45
+ b, c, h, w = x.shape
46
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
47
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
48
+ b, c, h, w = x.shape
49
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
50
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
51
+ self.encoder.pos_embed.data = pos_embed_data
52
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
53
+ return feature
54
+
55
+
56
+ class REPATrainer(BaseTrainer):
57
+ def __init__(
58
+ self,
59
+ scheduler: BaseScheduler,
60
+ loss_weight_fn:Callable=constant,
61
+ feat_loss_weight: float=0.5,
62
+ lognorm_t=False,
63
+ encoder_weight_path=None,
64
+ align_layer=8,
65
+ proj_denoiser_dim=256,
66
+ proj_hidden_dim=256,
67
+ proj_encoder_dim=256,
68
+ *args,
69
+ **kwargs
70
+ ):
71
+ super().__init__(*args, **kwargs)
72
+ self.lognorm_t = lognorm_t
73
+ self.scheduler = scheduler
74
+ self.loss_weight_fn = loss_weight_fn
75
+ self.feat_loss_weight = feat_loss_weight
76
+ self.align_layer = align_layer
77
+ self.encoder = DINOv2(encoder_weight_path)
78
+ no_grad(self.encoder)
79
+
80
+ self.proj = nn.Sequential(
81
+ nn.Sequential(
82
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
83
+ nn.SiLU(),
84
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
85
+ nn.SiLU(),
86
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
87
+ )
88
+ )
89
+
90
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
91
+ batch_size, c, height, width = x.shape
92
+ if self.lognorm_t:
93
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
94
+ else:
95
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
96
+ t = base_t
97
+
98
+ noise = torch.randn_like(x)
99
+ alpha = self.scheduler.alpha(t)
100
+ dalpha = self.scheduler.dalpha(t)
101
+ sigma = self.scheduler.sigma(t)
102
+ dsigma = self.scheduler.dsigma(t)
103
+
104
+ x_t = alpha * x + noise * sigma
105
+ v_t = dalpha * x + dsigma * noise
106
+
107
+ src_feature = []
108
+ def forward_hook(net, input, output):
109
+ src_feature.append(output)
110
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
111
+
112
+ out = net(x_t, t, y)
113
+ src_feature = self.proj(src_feature[0])
114
+ handle.remove()
115
+
116
+ with torch.no_grad():
117
+ dst_feature = self.encoder(raw_images)
118
+
119
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
120
+ cos_loss = 1 - cos_sim
121
+
122
+ weight = self.loss_weight_fn(alpha, sigma)
123
+ fm_loss = weight*(out - v_t)**2
124
+
125
+ out = dict(
126
+ fm_loss=fm_loss.mean(),
127
+ cos_loss=cos_loss.mean(),
128
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
129
+ )
130
+ return out
131
+
132
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
133
+ return self.proj.state_dict(
134
+ destination=destination,
135
+ prefix=prefix + "proj.",
136
+ keep_vars=keep_vars)
137
+
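The REPA trainer above grabs an intermediate denoiser feature with a standard PyTorch forward hook on blocks[align_layer - 1]. A minimal standalone sketch of that mechanism, with a toy nn.Sequential standing in for the transformer blocks (not the actual DDT/DiT model):

    import torch
    import torch.nn as nn

    blocks = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))

    grabbed = []
    def forward_hook(module, inputs, output):
        grabbed.append(output)               # stash this block's activations

    align_layer = 2
    handle = blocks[align_layer - 1].register_forward_hook(forward_hook)
    _ = blocks(torch.randn(4, 8))
    handle.remove()                          # detach the hook after the forward pass
    print(grabbed[0].shape)                  # torch.Size([4, 8])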
src/diffusion/pre_integral.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch
2
+
3
+ # lagrange interpolation
4
+ def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
5
+ '''
6
+ lagrange interpolation of order 1
7
+ Args:
8
+ t1: timestep of v1
9
+ v1: value field at t1
10
+ int_t_start: integration start time
11
+ int_t_end: integration end time
12
+ Returns:
13
+ integrated value
14
+ '''
15
+ int1 = (int_t_end-int_t_start)
16
+ return int1*v1, (int1/int1, )
17
+
18
+ def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
19
+ '''
20
+ lagrange interpolation of order 2
21
+ Args:
22
+ t1: timestep of v1
23
+ t2: timestep of v2
24
+ v1: value field at t1
25
+ v2: value field at t2
26
+ int_t_start: integration start time
27
+ int_t_end: integration end time
28
+ Returns:
29
+ integrated value
30
+ '''
31
+ int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2)
32
+ int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2)
33
+ int_sum = int1+int2
34
+ return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum)
35
+
36
+ def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
37
+ '''
38
+ lagrange interpolation of order 3
39
+ Args:
40
+ t1: timestep of v1
41
+ t2: timestep of v2
42
+ t3: timestep of v3
43
+ v1: value field at t1
44
+ v2: value field at t2
45
+ v3: value field at t3
46
+ int_t_start: integration start time
47
+ int_t_end: integration end time
48
+ Returns:
49
+ integrated value
50
+ '''
51
+ int1_denom = (t1-t2)*(t1-t3)
52
+ int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end
53
+ int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start
54
+ int1 = (int1_end - int1_start)/int1_denom
55
+ int2_denom = (t2-t1)*(t2-t3)
56
+ int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end
57
+ int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start
58
+ int2 = (int2_end - int2_start)/int2_denom
59
+ int3_denom = (t3-t1)*(t3-t2)
60
+ int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end
61
+ int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start
62
+ int3 = (int3_end - int3_start)/int3_denom
63
+ int_sum = int1+int2+int3
64
+ return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum)
65
+
66
+ def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
67
+ '''
68
+ lagrange interpolation of order 4
69
+ Args:
70
+ t1: timestep of v1
71
+ t2: timestep of v2
72
+ t3: timestep of v3
73
+ t4: timestep of v4
74
+ v1: value field at t1
75
+ v2: value field at t2
76
+ v3: value field at t3
77
+ v4: value field at t4
78
+ int_t_start: integration start time
79
+ int_t_end: integration end time
80
+ Returns:
81
+ integrated value
82
+ '''
83
+ int1_denom = (t1-t2)*(t1-t3)*(t1-t4)
84
+ int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end
85
+ int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start
86
+ int1 = (int1_end - int1_start)/int1_denom
87
+ int2_denom = (t2-t1)*(t2-t3)*(t2-t4)
88
+ int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end
89
+ int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start
90
+ int2 = (int2_end - int2_start)/int2_denom
91
+ int3_denom = (t3-t1)*(t3-t2)*(t3-t4)
92
+ int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end
93
+ int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start
94
+ int3 = (int3_end - int3_start)/int3_denom
95
+ int4_denom = (t4-t1)*(t4-t2)*(t4-t3)
96
+ int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end
97
+ int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start
98
+ int4 = (int4_end - int4_start)/int4_denom
99
+ int_sum = int1+int2+int3+int4
100
+ return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum)
101
+
102
+
103
+ def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
104
+ '''
105
+ lagrange interpolation
106
+ Args:
107
+ order: order of interpolation
108
+ pre_vs: value field at pre_ts
109
+ pre_ts: timesteps
110
+ int_t_start: integration start time
111
+ int_t_end: integration end time
112
+ Returns:
113
+ integrated value
114
+ '''
115
+ order = min(order, len(pre_vs), len(pre_ts))
116
+ if order == 1:
117
+ return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
118
+ elif order == 2:
119
+ return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
120
+ elif order == 3:
121
+ return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
122
+ elif order == 4:
123
+ return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
124
+ else:
125
+ raise ValueError('Invalid order')
126
+
127
+
128
+ def polynomial_integral(coeffs, int_t_start, int_t_end):
129
+ '''
130
+ polynomial integral
131
+ Args:
132
+ coeffs: coefficients of the polynomial
133
+ int_t_start: integration start time
134
+ int_t_end: integration end time
135
+ Returns:
136
+ integrated value
137
+ '''
138
+ orders = len(coeffs)
139
+ int_val = 0
140
+ for o in range(orders):
141
+ int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1))
142
+ return int_val
143
+
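A small usage sketch for lagrange_preint above (assuming the repository root is on PYTHONPATH): with two past evaluations of a constant velocity field, the order-2 rule integrates it exactly over the next interval, and the returned coefficients are normalized to sum to 1, which is how AdamLMSampler consumes them:

    import torch
    from src.diffusion.pre_integral import lagrange_preint

    pre_ts = [0.0, 0.25]                      # timesteps of two previous predictions
    pre_vs = [torch.ones(1), torch.ones(1)]   # constant velocity field v(t) = 1

    integral, coeffs = lagrange_preint(2, pre_vs, pre_ts, 0.25, 0.5)
    print(integral)       # tensor([0.2500]) == 1 * (0.5 - 0.25)
    print(sum(coeffs))    # approx 1.0, the normalized multistep weights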
src/diffusion/stateful_flow_matching/adam_sampling.py ADDED
@@ -0,0 +1,112 @@
1
+ import math
2
+ from src.diffusion.base.sampling import *
3
+ from src.diffusion.base.scheduling import *
4
+ from src.diffusion.pre_integral import *
5
+
6
+ from typing import Callable, List, Tuple
7
+
8
+ def ode_step_fn(x, v, dt, s, w):
9
+ return x + v * dt
10
+
11
+ def t2snr(t):
12
+ if isinstance(t, torch.Tensor):
13
+ return (t.clip(min=1e-8)/(1-t + 1e-8))
14
+ if isinstance(t, List) or isinstance(t, Tuple):
15
+ return [t2snr(t) for t in t]
16
+ t = max(t, 1e-8)
17
+ return (t/(1-t + 1e-8))
18
+
19
+ def t2logsnr(t):
20
+ if isinstance(t, torch.Tensor):
21
+ return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
22
+ if isinstance(t, List) or isinstance(t, Tuple):
23
+ return [t2logsnr(t) for t in t]
24
+ t = max(t, 1e-3)
25
+ return math.log(t/(1-t + 1e-3))
26
+
27
+ def t2isnr(t):
28
+ return 1/t2snr(t)
29
+
30
+ def nop(t):
31
+ return t
32
+
33
+ def shift_respace_fn(t, shift=3.0):
34
+ return t / (t + (1 - t) * shift)
35
+
36
+ import logging
37
+ logger = logging.getLogger(__name__)
38
+
39
+ class AdamLMSampler(BaseSampler):
40
+ def __init__(
41
+ self,
42
+ order: int = 2,
43
+ timeshift: float = 1.0,
44
+ state_refresh_rate: int = 1,
45
+ lms_transform_fn: Callable = nop,
46
+ w_scheduler: BaseScheduler = None,
47
+ step_fn: Callable = ode_step_fn,
48
+ *args,
49
+ **kwargs
50
+ ):
51
+ super().__init__(*args, **kwargs)
52
+ self.step_fn = step_fn
53
+ self.w_scheduler = w_scheduler
54
+ self.state_refresh_rate = state_refresh_rate
55
+
56
+ assert self.scheduler is not None
57
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
58
+ self.order = order
59
+ self.lms_transform_fn = lms_transform_fn
60
+
61
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
62
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
63
+ self.timesteps = shift_respace_fn(timesteps, timeshift)
64
+ self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
65
+ self._reparameterize_coeffs()
66
+
67
+ def _reparameterize_coeffs(self):
68
+ solver_coeffs = [[] for _ in range(self.num_steps)]
69
+ for i in range(0, self.num_steps):
70
+ pre_vs = [1.0, ]*(i+1)
71
+ pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
72
+ int_t_start = self.lms_transform_fn(self.timesteps[i])
73
+ int_t_end = self.lms_transform_fn(self.timesteps[i+1])
74
+
75
+ order_annealing = self.order #self.num_steps - i
76
+ order = min(self.order, i + 1, order_annealing)
77
+
78
+ _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
79
+ solver_coeffs[i] = coeffs
80
+ self.solver_coeffs = solver_coeffs
81
+
82
+ def _impl_sampling(self, net, noise, condition, uncondition):
83
+ """
84
+ sampling process of Euler sampler
85
+ -
86
+ """
87
+ batch_size = noise.shape[0]
88
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
89
+ x = x0 = noise
90
+ state = None
91
+ pred_trajectory = []
92
+ t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
93
+ timedeltas = self.timedeltas
94
+ solver_coeffs = self.solver_coeffs
95
+ for i in range(self.num_steps):
96
+ cfg_x = torch.cat([x, x], dim=0)
97
+ cfg_t = t_cur.repeat(2)
98
+ if i % self.state_refresh_rate == 0:
99
+ state = None
100
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
101
+ out = self.guidance_fn(out, self.guidances[i])
102
+ pred_trajectory.append(out)
103
+ out = torch.zeros_like(out)
104
+ order = len(self.solver_coeffs[i])
105
+ for j in range(order):
106
+ out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
107
+ v = out
108
+ dt = timedeltas[i]
109
+ x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
110
+ x = self.step_fn(x, v, dt, s=0, w=0)
111
+ t_cur += dt
112
+ return x
src/diffusion/stateful_flow_matching/sampling.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+
3
+ from src.diffusion.base.guidance import *
4
+ from src.diffusion.base.scheduling import *
5
+ from src.diffusion.base.sampling import *
6
+
7
+ from typing import Callable
8
+
9
+
10
+ def shift_respace_fn(t, shift=3.0):
11
+ return t / (t + (1 - t) * shift)
12
+
13
+ def ode_step_fn(x, v, dt, s, w):
14
+ return x + v * dt
15
+
16
+ def sde_mean_step_fn(x, v, dt, s, w):
17
+ return x + v * dt + s * w * dt
18
+
19
+ def sde_step_fn(x, v, dt, s, w):
20
+ return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
21
+
22
+ def sde_preserve_step_fn(x, v, dt, s, w):
23
+ return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
24
+
25
+
26
+ import logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+ class EulerSampler(BaseSampler):
30
+ def __init__(
31
+ self,
32
+ w_scheduler: BaseScheduler = None,
33
+ timeshift=1.0,
34
+ guidance_interval_min: float = 0.0,
35
+ guidance_interval_max: float = 1.0,
36
+ state_refresh_rate=1,
37
+ step_fn: Callable = ode_step_fn,
38
+ last_step=None,
39
+ last_step_fn: Callable = ode_step_fn,
40
+ *args,
41
+ **kwargs
42
+ ):
43
+ super().__init__(*args, **kwargs)
44
+ self.step_fn = step_fn
45
+ self.last_step = last_step
46
+ self.last_step_fn = last_step_fn
47
+ self.w_scheduler = w_scheduler
48
+ self.timeshift = timeshift
49
+ self.state_refresh_rate = state_refresh_rate
50
+ self.guidance_interval_min = guidance_interval_min
51
+ self.guidance_interval_max = guidance_interval_max
52
+
53
+ if self.last_step is None or self.num_steps == 1:
54
+ self.last_step = 1.0 / self.num_steps
55
+
56
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
57
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
58
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
59
+
60
+ assert self.last_step > 0.0
61
+ assert self.scheduler is not None
62
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
63
+ if self.w_scheduler is not None:
64
+ if self.step_fn == ode_step_fn:
65
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
66
+
67
+ def _impl_sampling(self, net, noise, condition, uncondition):
68
+ """
69
+ sampling process of Euler sampler
70
+ -
71
+ """
72
+ batch_size = noise.shape[0]
73
+ steps = self.timesteps.to(noise.device)
74
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
75
+ x = noise
76
+ state = None
77
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
78
+ dt = t_next - t_cur
79
+ t_cur = t_cur.repeat(batch_size)
80
+ sigma = self.scheduler.sigma(t_cur)
81
+ dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
82
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
83
+ if self.w_scheduler:
84
+ w = self.w_scheduler.w(t_cur)
85
+ else:
86
+ w = 0.0
87
+
88
+ cfg_x = torch.cat([x, x], dim=0)
89
+ cfg_t = t_cur.repeat(2)
90
+ if i % self.state_refresh_rate == 0:
91
+ state = None
92
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
93
+ if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
94
+ out = self.guidance_fn(out, self.guidance)
95
+ else:
96
+ out = self.guidance_fn(out, 1.0)
97
+ v = out
98
+ s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
99
+ if i < self.num_steps -1 :
100
+ x = self.step_fn(x, v, dt, s=s, w=w)
101
+ else:
102
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
103
+ return x
src/diffusion/stateful_flow_matching/scheduling.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ import torch
3
+ from src.diffusion.base.scheduling import *
4
+
5
+
6
+ class LinearScheduler(BaseScheduler):
7
+ def alpha(self, t) -> Tensor:
8
+ return (t).view(-1, 1, 1, 1)
9
+ def sigma(self, t) -> Tensor:
10
+ return (1-t).view(-1, 1, 1, 1)
11
+ def dalpha(self, t) -> Tensor:
12
+ return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
13
+ def dsigma(self, t) -> Tensor:
14
+ return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
15
+
16
+ # SoTA for ImageNet!
17
+ class GVPScheduler(BaseScheduler):
18
+ def alpha(self, t) -> Tensor:
19
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
20
+ def sigma(self, t) -> Tensor:
21
+ return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
22
+ def dalpha(self, t) -> Tensor:
23
+ return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
24
+ def dsigma(self, t) -> Tensor:
25
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
26
+ def w(self, t):
27
+ return torch.sin(t)**2
28
+
29
+ class ConstScheduler(BaseScheduler):
30
+ def w(self, t):
31
+ return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
32
+
33
+ from src.diffusion.ddpm.scheduling import VPScheduler
34
+ class VPBetaScheduler(VPScheduler):
35
+ def w(self, t):
36
+ return self.beta(t).view(-1, 1, 1, 1)
37
+
38
+
39
+
src/diffusion/stateful_flow_matching/sharing_sampling.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+
3
+ from src.diffusion.base.guidance import *
4
+ from src.diffusion.base.scheduling import *
5
+ from src.diffusion.base.sampling import *
6
+
7
+ from typing import Callable
8
+
9
+
10
+ def shift_respace_fn(t, shift=3.0):
11
+ return t / (t + (1 - t) * shift)
12
+
13
+ def ode_step_fn(x, v, dt, s, w):
14
+ return x + v * dt
15
+
16
+
17
+ import logging
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class EulerSampler(BaseSampler):
21
+ def __init__(
22
+ self,
23
+ w_scheduler: BaseScheduler = None,
24
+ timeshift=1.0,
25
+ guidance_interval_min: float = 0.0,
26
+ guidance_interval_max: float = 1.0,
27
+ state_refresh_rate=1,
28
+ step_fn: Callable = ode_step_fn,
29
+ last_step=None,
30
+ last_step_fn: Callable = ode_step_fn,
31
+ *args,
32
+ **kwargs
33
+ ):
34
+ super().__init__(*args, **kwargs)
35
+ self.step_fn = step_fn
36
+ self.last_step = last_step
37
+ self.last_step_fn = last_step_fn
38
+ self.w_scheduler = w_scheduler
39
+ self.timeshift = timeshift
40
+ self.state_refresh_rate = state_refresh_rate
41
+ self.guidance_interval_min = guidance_interval_min
42
+ self.guidance_interval_max = guidance_interval_max
43
+
44
+ if self.last_step is None or self.num_steps == 1:
45
+ self.last_step = 1.0 / self.num_steps
46
+
47
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
48
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
49
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
50
+
51
+ assert self.last_step > 0.0
52
+ assert self.scheduler is not None
53
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
54
+ if self.w_scheduler is not None:
55
+ if self.step_fn == ode_step_fn:
56
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
57
+
58
+ # init recompute
59
+ self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
60
+ self.recompute_timesteps = list(range(self.num_steps))
61
+
62
+ def sharing_dp(self, net, noise, condition, uncondition):
63
+ _, C, H, W = noise.shape
64
+ B = 8
65
+ template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
66
+ template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
67
+ template_uncondition = torch.full((B, ), 1000, device=condition.device)
68
+ _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
69
+ states = torch.stack(state_list)
70
+ N, B, L, C = states.shape
71
+ states = states.view(N, B*L, C )
72
+ states = states.permute(1, 0, 2)
73
+ states = torch.nn.functional.normalize(states, dim=-1)
74
+ states = states.to(torch.float64)  # do the similarity bmm in fp64; autocast does not support float64
75
+ sim = torch.bmm(states, states.transpose(1, 2))
76
+ sim = torch.mean(sim, dim=0).cpu()
77
+ error_map = (1-sim).tolist()
78
+
79
+ # init cum-error
80
+ for i in range(1, self.num_steps):
81
+ for j in range(0, i):
82
+ error_map[i][j] = error_map[i-1][j] + error_map[i][j]
83
+
84
+ # init dp and force 0 start
85
+ C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
86
+ P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
87
+ for i in range(1, self.num_steps+1):
88
+ C[1][i] = error_map[i - 1][0]
89
+ P[1][i] = 0
90
+
91
+ # dp state
92
+ for step in range(2, self.num_recompute_timesteps+1):
93
+ for i in range(step, self.num_steps+1):
94
+ min_value = 99999
95
+ min_index = -1
96
+ for j in range(step-1, i):
97
+ value = C[step-1][j] + error_map[i-1][j]
98
+ if value < min_value:
99
+ min_value = value
100
+ min_index = j
101
+ C[step][i] = min_value
102
+ P[step][i] = min_index
103
+
104
+ # trace back
105
+ timesteps = [self.num_steps,]
106
+ for i in range(self.num_recompute_timesteps, 0, -1):
107
+ idx = timesteps[-1]
108
+ timesteps.append(P[i][idx])
109
+ timesteps.reverse()
110
+
111
+ print("recompute timesteps solved by DP: ", timesteps)
112
+ return timesteps[:-1]
113
+
114
+ def _impl_sampling(self, net, noise, condition, uncondition):
115
+ """
116
+ sampling process of Euler sampler
117
+ -
118
+ """
119
+ batch_size = noise.shape[0]
120
+ steps = self.timesteps.to(noise.device)
121
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
122
+ x = noise
123
+ state = None
124
+ pooled_state_list = []
125
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
126
+ dt = t_next - t_cur
127
+ t_cur = t_cur.repeat(batch_size)
128
+ cfg_x = torch.cat([x, x], dim=0)
129
+ cfg_t = t_cur.repeat(2)
130
+ if i in self.recompute_timesteps:
131
+ state = None
132
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
133
+ if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
134
+ out = self.guidance_fn(out, self.guidance)
135
+ else:
136
+ out = self.guidance_fn(out, 1.0)
137
+ v = out
138
+ if i < self.num_steps -1 :
139
+ x = self.step_fn(x, v, dt, s=0.0, w=0.0)
140
+ else:
141
+ x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
142
+ pooled_state_list.append(state)
143
+ return x, pooled_state_list
144
+
145
+ def __call__(self, net, noise, condition, uncondition):
146
+ if len(self.recompute_timesteps) != self.num_recompute_timesteps:
147
+ self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
148
+ denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
149
+ return denoised
src/diffusion/stateful_flow_matching/training.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ from typing import Callable
3
+ from src.diffusion.base.training import *
4
+ from src.diffusion.base.scheduling import BaseScheduler
5
+
6
+ def inverse_sigma(alpha, sigma):
7
+ return 1/sigma**2
8
+ def snr(alpha, sigma):
9
+ return alpha/sigma
10
+ def minsnr(alpha, sigma, threshold=5):
11
+ return torch.clip(alpha/sigma, min=threshold)
12
+ def maxsnr(alpha, sigma, threshold=5):
13
+ return torch.clip(alpha/sigma, max=threshold)
14
+ def constant(alpha, sigma):
15
+ return 1
16
+
17
+ class FlowMatchingTrainer(BaseTrainer):
18
+ def __init__(
19
+ self,
20
+ scheduler: BaseScheduler,
21
+ loss_weight_fn:Callable=constant,
22
+ lognorm_t=False,
23
+ *args,
24
+ **kwargs
25
+ ):
26
+ super().__init__(*args, **kwargs)
27
+ self.lognorm_t = lognorm_t
28
+ self.scheduler = scheduler
29
+ self.loss_weight_fn = loss_weight_fn
30
+
31
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
32
+ batch_size = x.shape[0]
33
+ if self.lognorm_t:
34
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
35
+ else:
36
+ t = torch.rand(batch_size).to(x.device, x.dtype)
37
+ noise = torch.randn_like(x)
38
+ alpha = self.scheduler.alpha(t)
39
+ dalpha = self.scheduler.dalpha(t)
40
+ sigma = self.scheduler.sigma(t)
41
+ dsigma = self.scheduler.dsigma(t)
42
+ w = self.scheduler.w(t)
43
+
44
+ x_t = alpha * x + noise * sigma
45
+ v_t = dalpha * x + dsigma * noise
46
+ out, _ = net(x_t, t, y)
47
+
48
+ weight = self.loss_weight_fn(alpha, sigma)
49
+
50
+ loss = weight*(out - v_t)**2
51
+
52
+ out = dict(
53
+ loss=loss.mean(),
54
+ )
55
+ return out
src/diffusion/stateful_flow_matching/training_repa.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import copy
3
+ import timm
4
+ from torch.nn import Parameter
5
+
6
+ from src.utils.no_grad import no_grad
7
+ from typing import Callable, Iterator, Tuple
8
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
+ from torchvision.transforms import Normalize
10
+ from src.diffusion.base.training import *
11
+ from src.diffusion.base.scheduling import BaseScheduler
12
+
13
+ def inverse_sigma(alpha, sigma):
14
+ return 1/sigma**2
15
+ def snr(alpha, sigma):
16
+ return alpha/sigma
17
+ def minsnr(alpha, sigma, threshold=5):
18
+ return torch.clip(alpha/sigma, min=threshold)
19
+ def maxsnr(alpha, sigma, threshold=5):
20
+ return torch.clip(alpha/sigma, max=threshold)
21
+ def constant(alpha, sigma):
22
+ return 1
23
+
24
+
25
+ class DINOv2(nn.Module):
26
+ def __init__(self, weight_path:str):
27
+ super(DINOv2, self).__init__()
28
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
29
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
30
+ self.encoder.head = torch.nn.Identity()
31
+ self.patch_size = self.encoder.patch_embed.patch_size
32
+ self.precomputed_pos_embed = dict()
33
+
34
+ def fetch_pos(self, h, w):
35
+ key = (h, w)
36
+ if key in self.precomputed_pos_embed:
37
+ return self.precomputed_pos_embed[key]
38
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
39
+ self.pos_embed.data, [h, w],
40
+ )
41
+ self.precomputed_pos_embed[key] = value
42
+ return value
43
+
44
+ def forward(self, x):
45
+ b, c, h, w = x.shape
46
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
47
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
48
+ b, c, h, w = x.shape
49
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
50
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
51
+ self.encoder.pos_embed.data = pos_embed_data
52
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
53
+ return feature
54
+
55
+
56
+ class REPATrainer(BaseTrainer):
57
+ def __init__(
58
+ self,
59
+ scheduler: BaseScheduler,
60
+ loss_weight_fn:Callable=constant,
61
+ feat_loss_weight: float=0.5,
62
+ lognorm_t=False,
63
+ encoder_weight_path=None,
64
+ align_layer=8,
65
+ proj_denoiser_dim=256,
66
+ proj_hidden_dim=256,
67
+ proj_encoder_dim=256,
68
+ *args,
69
+ **kwargs
70
+ ):
71
+ super().__init__(*args, **kwargs)
72
+ self.lognorm_t = lognorm_t
73
+ self.scheduler = scheduler
74
+ self.loss_weight_fn = loss_weight_fn
75
+ self.feat_loss_weight = feat_loss_weight
76
+ self.align_layer = align_layer
77
+ self.encoder = DINOv2(encoder_weight_path)
78
+ self.proj_encoder_dim = proj_encoder_dim
79
+ no_grad(self.encoder)
80
+
81
+ self.proj = nn.Sequential(
82
+ nn.Sequential(
83
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
84
+ nn.SiLU(),
85
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
86
+ nn.SiLU(),
87
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
88
+ )
89
+ )
90
+
91
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
92
+ batch_size, c, height, width = x.shape
93
+ if self.lognorm_t:
94
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
95
+ else:
96
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
97
+ t = base_t
98
+
99
+ noise = torch.randn_like(x)
100
+ alpha = self.scheduler.alpha(t)
101
+ dalpha = self.scheduler.dalpha(t)
102
+ sigma = self.scheduler.sigma(t)
103
+ dsigma = self.scheduler.dsigma(t)
104
+
105
+ x_t = alpha * x + noise * sigma
106
+ v_t = dalpha * x + dsigma * noise
107
+ src_feature = []
108
+ def forward_hook(net, input, output):
109
+ src_feature.append(output)
110
+
111
+ if getattr(net, "blocks", None) is not None:
112
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
113
+ else:
114
+ handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
115
+
116
+ out, _ = net(x_t, t, y)
117
+ src_feature = self.proj(src_feature[0])
118
+ handle.remove()
119
+
120
+ with torch.no_grad():
121
+ dst_feature = self.encoder(raw_images)
122
+
123
+ if dst_feature.shape[1] != src_feature.shape[1]:
124
+ dst_length = dst_feature.shape[1]
125
+ rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
126
+ dst_height = (dst_length)**0.5 * (height/width)**0.5
127
+ dst_width = (dst_length)**0.5 * (width/height)**0.5
128
+ dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
129
+ dst_feature = dst_feature.permute(0, 3, 1, 2)
130
+ dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
131
+ dst_feature = dst_feature.permute(0, 2, 3, 1)
132
+ dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
133
+
134
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
135
+ cos_loss = 1 - cos_sim
136
+
137
+ weight = self.loss_weight_fn(alpha, sigma)
138
+ fm_loss = weight*(out - v_t)**2
139
+
140
+ out = dict(
141
+ fm_loss=fm_loss.mean(),
142
+ cos_loss=cos_loss.mean(),
143
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
144
+ )
145
+ return out
146
+
147
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
148
+ return self.proj.state_dict(
149
+ destination=destination,
150
+ prefix=prefix + "proj.",
151
+ keep_vars=keep_vars)
152
+
src/lightning_data.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Any
2
+ import torch
3
+ import copy
4
+ import lightning.pytorch as pl
5
+ from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
6
+ from torch.utils.data import DataLoader
7
+ from src.data.dataset.randn import RandomNDataset
8
+ from src.data.var_training import VARTransformEngine
9
+
10
+ def collate_fn(batch):
11
+ new_batch = copy.deepcopy(batch)
12
+ new_batch = list(zip(*new_batch))
13
+ for i in range(len(new_batch)):
14
+ if isinstance(new_batch[i][0], torch.Tensor):
15
+ try:
16
+ new_batch[i] = torch.stack(new_batch[i], dim=0)
17
+ except:
18
+ print("Warning: could not stack tensors")
19
+ return new_batch
20
+
21
+ class DataModule(pl.LightningDataModule):
22
+ def __init__(self,
23
+ train_root,
24
+ test_nature_root,
25
+ test_gen_root,
26
+ train_image_size=64,
27
+ train_batch_size=64,
28
+ train_num_workers=8,
29
+ var_transform_engine: VARTransformEngine = None,
30
+ train_prefetch_factor=2,
31
+ train_dataset: str = None,
32
+ eval_batch_size=32,
33
+ eval_num_workers=4,
34
+ eval_max_num_instances=50000,
35
+ pred_batch_size=32,
36
+ pred_num_workers=4,
37
+ pred_seeds:str=None,
38
+ pred_selected_classes=None,
39
+ num_classes=1000,
40
+ latent_shape=(4,64,64),
41
+ ):
42
+ super().__init__()
43
+ pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None
44
+
45
+ self.train_root = train_root
46
+ self.train_image_size = train_image_size
47
+ self.train_dataset = train_dataset
48
+ # stupid data_convert override, just to make nebular happy
49
+ self.train_batch_size = train_batch_size
50
+ self.train_num_workers = train_num_workers
51
+ self.train_prefetch_factor = train_prefetch_factor
52
+
53
+ self.test_nature_root = test_nature_root
54
+ self.test_gen_root = test_gen_root
55
+ self.eval_max_num_instances = eval_max_num_instances
56
+ self.pred_seeds = pred_seeds
57
+ self.num_classes = num_classes
58
+ self.latent_shape = latent_shape
59
+
60
+ self.eval_batch_size = eval_batch_size
61
+ self.pred_batch_size = pred_batch_size
62
+
63
+ self.pred_num_workers = pred_num_workers
64
+ self.eval_num_workers = eval_num_workers
65
+
66
+ self.pred_selected_classes = pred_selected_classes
67
+
68
+ self._train_dataloader = None
69
+ self.var_transform_engine = var_transform_engine
70
+
71
+ def setup(self, stage: str) -> None:
72
+ if stage == "fit":
73
+ assert self.train_dataset is not None
74
+ if self.train_dataset == "pix_imagenet64":
75
+ from src.data.dataset.imagenet import PixImageNet64
76
+ self.train_dataset = PixImageNet64(
77
+ root=self.train_root,
78
+ )
79
+ elif self.train_dataset == "pix_imagenet128":
80
+ from src.data.dataset.imagenet import PixImageNet128
81
+ self.train_dataset = PixImageNet128(
82
+ root=self.train_root,
83
+ )
84
+ elif self.train_dataset == "imagenet256":
85
+ from src.data.dataset.imagenet import ImageNet256
86
+ self.train_dataset = ImageNet256(
87
+ root=self.train_root,
88
+ )
89
+ elif self.train_dataset == "pix_imagenet256":
90
+ from src.data.dataset.imagenet import PixImageNet256
91
+ self.train_dataset = PixImageNet256(
92
+ root=self.train_root,
93
+ )
94
+ elif self.train_dataset == "imagenet512":
95
+ from src.data.dataset.imagenet import ImageNet512
96
+ self.train_dataset = ImageNet512(
97
+ root=self.train_root,
98
+ )
99
+ elif self.train_dataset == "pix_imagenet512":
100
+ from src.data.dataset.imagenet import PixImageNet512
101
+ self.train_dataset = PixImageNet512(
102
+ root=self.train_root,
103
+ )
104
+ else:
105
+ raise NotImplementedError("no such dataset")
106
+
107
+ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
108
+ if self.var_transform_engine and self.trainer.training:
109
+ batch = self.var_transform_engine(batch)
110
+ return batch
111
+
112
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
113
+ global_rank = self.trainer.global_rank
114
+ world_size = self.trainer.world_size
115
+ from torch.utils.data import DistributedSampler
116
+ sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
117
+ self._train_dataloader = DataLoader(
118
+ self.train_dataset,
119
+ self.train_batch_size,
120
+ timeout=6000,
121
+ num_workers=self.train_num_workers,
122
+ prefetch_factor=self.train_prefetch_factor,
123
+ sampler=sampler,
124
+ collate_fn=collate_fn,
125
+ )
126
+ return self._train_dataloader
127
+
128
+ def val_dataloader(self) -> EVAL_DATALOADERS:
129
+ global_rank = self.trainer.global_rank
130
+ world_size = self.trainer.world_size
131
+ self.eval_dataset = RandomNDataset(
132
+ latent_shape=self.latent_shape,
133
+ num_classes=self.num_classes,
134
+ max_num_instances=self.eval_max_num_instances,
135
+ )
136
+ from torch.utils.data import DistributedSampler
137
+ sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
138
+ return DataLoader(self.eval_dataset, self.eval_batch_size,
139
+ num_workers=self.eval_num_workers,
140
+ prefetch_factor=2,
141
+ collate_fn=collate_fn,
142
+ sampler=sampler
143
+ )
144
+
145
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
146
+ global_rank = self.trainer.global_rank
147
+ world_size = self.trainer.world_size
148
+ self.pred_dataset = RandomNDataset(
149
+ seeds= self.pred_seeds,
150
+ max_num_instances=50000,
151
+ num_classes=self.num_classes,
152
+ selected_classes=self.pred_selected_classes,
153
+ latent_shape=self.latent_shape,
154
+ )
155
+ from torch.utils.data import DistributedSampler
156
+ sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
157
+ return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
158
+ num_workers=self.pred_num_workers,
159
+ prefetch_factor=4,
160
+ collate_fn=collate_fn,
161
+ sampler=sampler
162
+ )
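A minimal sketch of how this DataModule is driven by a Lightning `Trainer` for sampling. The dataloaders read `trainer.global_rank` and `trainer.world_size`, so they are meant to be built by the trainer rather than called standalone; the paths, latent shape, and batch sizes below are illustrative assumptions, only the argument names come from the class above:

```python
import lightning.pytorch as pl

dm = DataModule(
    train_root="/data/imagenet/train",        # assumption
    test_nature_root="/data/imagenet/val",    # assumption
    test_gen_root="/tmp/generated",           # assumption
    train_dataset="imagenet256",
    latent_shape=(4, 32, 32),                 # assumption: SD-VAE latents for 256x256 images
    num_classes=1000,
    pred_batch_size=32,
    pred_seeds="0,1,2,3",
)
# use_distributed_sampler=False because the module builds its own DistributedSampler
trainer = pl.Trainer(accelerator="gpu", devices=1, use_distributed_sampler=False)
# trainer.predict(model, datamodule=dm)       # model: the LightningModel defined below
```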
src/lightning_model.py ADDED
@@ -0,0 +1,123 @@
1
+ from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
2
+ import os.path
3
+ import copy
4
+ import torch
5
+ import torch.nn as nn
6
+ import lightning.pytorch as pl
7
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
8
+ from torch.optim.lr_scheduler import LRScheduler
9
+ from torch.optim import Optimizer
10
+ from lightning.pytorch.callbacks import Callback
11
+
12
+
13
+ from src.models.vae import BaseVAE, fp2uint8
14
+ from src.models.conditioner import BaseConditioner
15
+ from src.utils.model_loader import ModelLoader
16
+ from src.callbacks.simple_ema import SimpleEMA
17
+ from src.diffusion.base.sampling import BaseSampler
18
+ from src.diffusion.base.training import BaseTrainer
19
+ from src.utils.no_grad import no_grad, filter_nograd_tensors
20
+ from src.utils.copy import copy_params
21
+
22
+ EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
23
+ OptimizerCallable = Callable[[Iterable], Optimizer]
24
+ LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
25
+
26
+
27
+ class LightningModel(pl.LightningModule):
28
+ def __init__(self,
29
+ vae: BaseVAE,
30
+ conditioner: BaseConditioner,
31
+ denoiser: nn.Module,
32
+ diffusion_trainer: BaseTrainer,
33
+ diffusion_sampler: BaseSampler,
34
+ ema_tracker: Optional[EMACallable] = None,
35
+ optimizer: OptimizerCallable = None,
36
+ lr_scheduler: LRSchedulerCallable = None,
37
+ ):
38
+ super().__init__()
39
+ self.vae = vae
40
+ self.conditioner = conditioner
41
+ self.denoiser = denoiser
42
+ self.ema_denoiser = copy.deepcopy(self.denoiser)
43
+ self.diffusion_sampler = diffusion_sampler
44
+ self.diffusion_trainer = diffusion_trainer
45
+ self.ema_tracker = ema_tracker
46
+ self.optimizer = optimizer
47
+ self.lr_scheduler = lr_scheduler
48
+ # self.model_loader = ModelLoader()
49
+
50
+ self._strict_loading = False
51
+
52
+ def configure_model(self) -> None:
53
+ self.trainer.strategy.barrier()
54
+ # self.denoiser = self.model_loader.load(self.denoiser)
55
+ copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
56
+
57
+ # self.denoiser = torch.compile(self.denoiser)
58
+ # disable grad for conditioner and vae
59
+ no_grad(self.conditioner)
60
+ no_grad(self.vae)
61
+ no_grad(self.diffusion_sampler)
62
+ no_grad(self.ema_denoiser)
63
+
64
+ def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
65
+ ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
66
+ return [ema_tracker]
67
+
68
+ def configure_optimizers(self) -> OptimizerLRScheduler:
69
+ params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
70
+ params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
71
+ optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
72
+ if self.lr_scheduler is None:
73
+ return dict(
74
+ optimizer=optimizer
75
+ )
76
+ else:
77
+ lr_scheduler = self.lr_scheduler(optimizer)
78
+ return dict(
79
+ optimizer=optimizer,
80
+ lr_scheduler=lr_scheduler
81
+ )
82
+
83
+ def training_step(self, batch, batch_idx):
84
+ raw_images, x, y = batch
85
+ with torch.no_grad():
86
+ x = self.vae.encode(x)
87
+ condition, uncondition = self.conditioner(y)
88
+ loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
89
+ self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
90
+ return loss["loss"]
91
+
92
+ def predict_step(self, batch, batch_idx):
93
+ xT, y, metadata = batch
94
+ with torch.no_grad():
95
+ condition, uncondition = self.conditioner(y)
96
+ # Sample images:
97
+ samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
98
+ samples = self.vae.decode(samples)
99
+ # fp32 -1,1 -> uint8 0,255
100
+ samples = fp2uint8(samples)
101
+ return samples
102
+
103
+ def validation_step(self, batch, batch_idx):
104
+ samples = self.predict_step(batch, batch_idx)
105
+ return samples
106
+
107
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
108
+ if destination is None:
109
+ destination = {}
110
+ self._save_to_state_dict(destination, prefix, keep_vars)
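+ # export only the denoiser, EMA denoiser and trainer weights; the frozen VAE and conditioner are not saved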
111
+ self.denoiser.state_dict(
112
+ destination=destination,
113
+ prefix=prefix+"denoiser.",
114
+ keep_vars=keep_vars)
115
+ self.ema_denoiser.state_dict(
116
+ destination=destination,
117
+ prefix=prefix+"ema_denoiser.",
118
+ keep_vars=keep_vars)
119
+ self.diffusion_trainer.state_dict(
120
+ destination=destination,
121
+ prefix=prefix+"diffusion_trainer.",
122
+ keep_vars=keep_vars)
123
+ return destination
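Because this `state_dict` exports only `denoiser.*`, `ema_denoiser.*`, and `diffusion_trainer.*` entries, the EMA weights can be pulled out of a saved checkpoint directly. A short sketch, with a hypothetical checkpoint filename:

```python
import torch

ckpt = torch.load("ddt_xl.ckpt", map_location="cpu")   # filename is hypothetical
state = ckpt.get("state_dict", ckpt)
ema = {k.removeprefix("ema_denoiser."): v
       for k, v in state.items() if k.startswith("ema_denoiser.")}
# denoiser.load_state_dict(ema)   # load the EMA weights for sampling
```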
src/models/__init__.py ADDED
File without changes
src/models/conditioner.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class BaseConditioner(nn.Module):
5
+ def __init__(self):
6
+ super(BaseConditioner, self).__init__()
7
+
8
+ def _impl_condition(self, y):
9
+ ...
10
+ def _impl_uncondition(self, y):
11
+ ...
12
+ def __call__(self, y):
13
+ condition = self._impl_condition(y)
14
+ uncondition = self._impl_uncondition(y)
15
+ return condition, uncondition
16
+
17
+ class LabelConditioner(BaseConditioner):
18
+ def __init__(self, null_class):
19
+ super().__init__()
20
+ self.null_condition = null_class
21
+
22
+ def _impl_condition(self, y):
23
+ return torch.tensor(y).long().cuda()
24
+
25
+ def _impl_uncondition(self, y):
26
+ return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()
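A short usage sketch (it needs a CUDA device, since both branches are moved to GPU). The label values are arbitrary example class ids; the null class id matches the `null_class: 1000` setting used in the configs:

```python
conditioner = LabelConditioner(null_class=1000)
condition, uncondition = conditioner([207, 360, 387])
# condition   -> tensor([207, 360, 387], device='cuda:0')
# uncondition -> tensor([1000, 1000, 1000], device='cuda:0')
# The sampler evaluates the denoiser on both branches to apply classifier-free guidance.
```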
src/models/denoiser/__init__.py ADDED
File without changes
src/models/denoiser/decoupled_improved_dit.py ADDED
@@ -0,0 +1,308 @@
1
+ import functools
2
+ from typing import Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+ from torch.nn.init import zeros_
8
+ from torch.nn.modules.module import T
9
+
10
+ # from torch.nn.attention.flex_attention import flex_attention, create_block_mask
11
+ from torch.nn.functional import scaled_dot_product_attention
12
+
13
+
14
+ def modulate(x, shift, scale):
15
+ return x * (1 + scale) + shift
16
+
17
+ class Embed(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_chans: int = 3,
21
+ embed_dim: int = 768,
22
+ norm_layer = None,
23
+ bias: bool = True,
24
+ ):
25
+ super().__init__()
26
+ self.in_chans = in_chans
27
+ self.embed_dim = embed_dim
28
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+ def forward(self, x):
31
+ x = self.proj(x)
32
+ x = self.norm(x)
33
+ return x
34
+
35
+ class TimestepEmbedder(nn.Module):
36
+
37
+ def __init__(self, hidden_size, frequency_embedding_size=256):
38
+ super().__init__()
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(hidden_size, hidden_size, bias=True),
43
+ )
44
+ self.frequency_embedding_size = frequency_embedding_size
45
+
46
+ @staticmethod
47
+ def timestep_embedding(t, dim, max_period=10):
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
51
+ )
52
+ args = t[..., None].float() * freqs[None, ...]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+ class LabelEmbedder(nn.Module):
64
+ def __init__(self, num_classes, hidden_size):
65
+ super().__init__()
66
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
67
+ self.num_classes = num_classes
68
+
69
+ def forward(self, labels,):
70
+ embeddings = self.embedding_table(labels)
71
+ return embeddings
72
+
73
+ class FinalLayer(nn.Module):
74
+ def __init__(self, hidden_size, out_channels):
75
+ super().__init__()
76
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
77
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
78
+ self.adaLN_modulation = nn.Sequential(
79
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
80
+ )
81
+
82
+ def forward(self, x, c):
83
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
84
+ x = modulate(self.norm_final(x), shift, scale)
85
+ x = self.linear(x)
86
+ return x
87
+
88
+ class RMSNorm(nn.Module):
89
+ def __init__(self, hidden_size, eps=1e-6):
90
+ """
91
+ LlamaRMSNorm is equivalent to T5LayerNorm
92
+ """
93
+ super().__init__()
94
+ self.weight = nn.Parameter(torch.ones(hidden_size))
95
+ self.variance_epsilon = eps
96
+
97
+ def forward(self, hidden_states):
98
+ input_dtype = hidden_states.dtype
99
+ hidden_states = hidden_states.to(torch.float32)
100
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
101
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
102
+ return self.weight * hidden_states.to(input_dtype)
103
+
104
+ class FeedForward(nn.Module):
105
+ def __init__(
106
+ self,
107
+ dim: int,
108
+ hidden_dim: int,
109
+ ):
110
+ super().__init__()
111
+ hidden_dim = int(2 * hidden_dim / 3)
112
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
113
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
114
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
115
+ def forward(self, x):
116
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
117
+ return x
118
+
119
+ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
120
+ # assert H * H == end
121
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
122
+ x_pos = torch.linspace(0, scale, width)
123
+ y_pos = torch.linspace(0, scale, height)
124
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
125
+ y_pos = y_pos.reshape(-1)
126
+ x_pos = x_pos.reshape(-1)
127
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
128
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
129
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
130
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
131
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
132
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
133
+ freqs_cis = freqs_cis.reshape(height*width, -1)
134
+ return freqs_cis
135
+
136
+
137
+ def apply_rotary_emb(
138
+ xq: torch.Tensor,
139
+ xk: torch.Tensor,
140
+ freqs_cis: torch.Tensor,
141
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
142
+ freqs_cis = freqs_cis[None, :, None, :]
143
+ # xq : B N H Hc
144
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
145
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
146
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
147
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
148
+ return xq_out.type_as(xq), xk_out.type_as(xk)
149
+
150
+
151
+ class RAttention(nn.Module):
152
+ def __init__(
153
+ self,
154
+ dim: int,
155
+ num_heads: int = 8,
156
+ qkv_bias: bool = False,
157
+ qk_norm: bool = True,
158
+ attn_drop: float = 0.,
159
+ proj_drop: float = 0.,
160
+ norm_layer: nn.Module = RMSNorm,
161
+ ) -> None:
162
+ super().__init__()
163
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
164
+
165
+ self.dim = dim
166
+ self.num_heads = num_heads
167
+ self.head_dim = dim // num_heads
168
+ self.scale = self.head_dim ** -0.5
169
+
170
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
171
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
172
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+
177
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
178
+ B, N, C = x.shape
179
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
180
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
181
+ q = self.q_norm(q)
182
+ k = self.k_norm(k)
183
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
184
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
185
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
186
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
187
+
188
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
189
+
190
+ x = x.transpose(1, 2).reshape(B, N, C)
191
+ x = self.proj(x)
192
+ x = self.proj_drop(x)
193
+ return x
194
+
195
+
196
+
197
+ class DDTBlock(nn.Module):
198
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
199
+ super().__init__()
200
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
201
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
202
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
203
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
204
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
205
+ self.adaLN_modulation = nn.Sequential(
206
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
207
+ )
208
+
209
+ def forward(self, x, c, pos, mask=None):
210
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
211
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
212
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
213
+ return x
214
+
215
+
216
+ class DDT(nn.Module):
217
+ def __init__(
218
+ self,
219
+ in_channels=4,
220
+ num_groups=12,
221
+ hidden_size=1152,
222
+ num_blocks=18,
223
+ num_encoder_blocks=4,
224
+ patch_size=2,
225
+ num_classes=1000,
226
+ learn_sigma=True,
227
+ deep_supervision=0,
228
+ weight_path=None,
229
+ load_ema=False,
230
+ ):
231
+ super().__init__()
232
+ self.deep_supervision = deep_supervision
233
+ self.learn_sigma = learn_sigma
234
+ self.in_channels = in_channels
235
+ self.out_channels = in_channels
236
+ self.hidden_size = hidden_size
237
+ self.num_groups = num_groups
238
+ self.num_blocks = num_blocks
239
+ self.num_encoder_blocks = num_encoder_blocks
240
+ self.patch_size = patch_size
241
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
242
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
243
+ self.t_embedder = TimestepEmbedder(hidden_size)
244
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
245
+
246
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
247
+
248
+ self.weight_path = weight_path
249
+
250
+ self.load_ema = load_ema
251
+ self.blocks = nn.ModuleList([
252
+ DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
253
+ ])
254
+ self.initialize_weights()
255
+ self.precompute_pos = dict()
256
+
257
+ def fetch_pos(self, height, width, device):
258
+ if (height, width) in self.precompute_pos:
259
+ return self.precompute_pos[(height, width)].to(device)
260
+ else:
261
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
262
+ self.precompute_pos[(height, width)] = pos
263
+ return pos
264
+
265
+ def initialize_weights(self):
266
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
267
+ w = self.x_embedder.proj.weight.data
268
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
269
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
270
+
271
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
272
+ w = self.s_embedder.proj.weight.data
273
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
274
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
275
+
276
+ # Initialize label embedding table:
277
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
278
+
279
+ # Initialize timestep embedding MLP:
280
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
281
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
282
+
283
+ # Zero-out output layers:
284
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
285
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
286
+ nn.init.constant_(self.final_layer.linear.weight, 0)
287
+ nn.init.constant_(self.final_layer.linear.bias, 0)
288
+
289
+
290
+ def forward(self, x, t, y, s=None, mask=None):
291
+ B, _, H, W = x.shape
292
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
293
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
294
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
295
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
296
+ c = nn.functional.silu(t + y)
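+ # c (time + class) conditions the encoder blocks; the resulting state s conditions the decoder blocks below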
297
+ if s is None:
298
+ s = self.s_embedder(x)
299
+ for i in range(self.num_encoder_blocks):
300
+ s = self.blocks[i](s, c, pos, mask)
301
+ s = nn.functional.silu(t + s)
302
+
303
+ x = self.x_embedder(x)
304
+ for i in range(self.num_encoder_blocks, self.num_blocks):
305
+ x = self.blocks[i](x, s, pos, None)
306
+ x = self.final_layer(x, s)
307
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
308
+ return x, s
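This split is what makes encoder-state sharing possible at sampling time: the first `num_encoder_blocks` blocks build the condition state `s` from the noisy latent, and the remaining blocks decode with `s`; passing a cached `s` back in skips the encoder entirely. A small-scale sketch, where the tiny widths are assumptions for illustration rather than the XL configuration:

```python
import torch

model = DDT(in_channels=4, num_groups=4, hidden_size=256,
            num_blocks=8, num_encoder_blocks=6, num_classes=1000)
x_t = torch.randn(2, 4, 32, 32)          # noisy latents
t = torch.rand(2)                        # diffusion times
y = torch.randint(0, 1000, (2,))         # class labels

v, s = model(x_t, t, y)                  # full pass: encoder + decoder
v2, _ = model(x_t, t, y, s=s)            # decoder-only pass, reusing the cached state
```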
src/models/denoiser/improved_dit.py ADDED
@@ -0,0 +1,301 @@
1
+ import functools
2
+ from typing import Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+ from torch.nn.init import zeros_
8
+ from torch.nn.modules.module import T
9
+
10
+ # from torch.nn.attention.flex_attention import flex_attention, create_block_mask
11
+ from torch.nn.functional import scaled_dot_product_attention
12
+
13
+
14
+ def modulate(x, shift, scale):
15
+ return x * (1 + scale) + shift
16
+
17
+ class Embed(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_chans: int = 3,
21
+ embed_dim: int = 768,
22
+ norm_layer = None,
23
+ bias: bool = True,
24
+ ):
25
+ super().__init__()
26
+ self.in_chans = in_chans
27
+ self.embed_dim = embed_dim
28
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+ def forward(self, x):
31
+ x = self.proj(x)
32
+ x = self.norm(x)
33
+ return x
34
+
35
+ class TimestepEmbedder(nn.Module):
36
+
37
+ def __init__(self, hidden_size, frequency_embedding_size=256):
38
+ super().__init__()
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(hidden_size, hidden_size, bias=True),
43
+ )
44
+ self.frequency_embedding_size = frequency_embedding_size
45
+
46
+ @staticmethod
47
+ def timestep_embedding(t, dim, max_period=10):
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
51
+ )
52
+ args = t[..., None].float() * freqs[None, ...]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+ class LabelEmbedder(nn.Module):
64
+ def __init__(self, num_classes, hidden_size):
65
+ super().__init__()
66
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
67
+ self.num_classes = num_classes
68
+
69
+ def forward(self, labels,):
70
+ embeddings = self.embedding_table(labels)
71
+ return embeddings
72
+
73
+ class FinalLayer(nn.Module):
74
+ def __init__(self, hidden_size, out_channels):
75
+ super().__init__()
76
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
77
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
78
+ self.adaLN_modulation = nn.Sequential(
79
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
80
+ )
81
+
82
+ def forward(self, x, c):
83
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
84
+ x = modulate(self.norm_final(x), shift, scale)
85
+ x = self.linear(x)
86
+ return x
87
+
88
+ class RMSNorm(nn.Module):
89
+ def __init__(self, hidden_size, eps=1e-6):
90
+ """
91
+ LlamaRMSNorm is equivalent to T5LayerNorm
92
+ """
93
+ super().__init__()
94
+ self.weight = nn.Parameter(torch.ones(hidden_size))
95
+ self.variance_epsilon = eps
96
+
97
+ def forward(self, hidden_states):
98
+ input_dtype = hidden_states.dtype
99
+ hidden_states = hidden_states.to(torch.float32)
100
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
101
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
102
+ return (self.weight * hidden_states).to(input_dtype)
103
+
104
+ class FeedForward(nn.Module):
105
+ def __init__(
106
+ self,
107
+ dim: int,
108
+ hidden_dim: int,
109
+ ):
110
+ super().__init__()
111
+ hidden_dim = int(2 * hidden_dim / 3)
112
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
113
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
114
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
115
+ def forward(self, x):
116
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
117
+ return x
118
+
119
+ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
120
+ # assert H * H == end
121
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
122
+ x_pos = torch.linspace(0, scale, width)
123
+ y_pos = torch.linspace(0, scale, height)
124
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
125
+ y_pos = y_pos.reshape(-1)
126
+ x_pos = x_pos.reshape(-1)
127
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
128
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
129
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
130
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
131
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
132
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
133
+ freqs_cis = freqs_cis.reshape(height*width, -1)
134
+ return freqs_cis
135
+
136
+
137
+ def apply_rotary_emb(
138
+ xq: torch.Tensor,
139
+ xk: torch.Tensor,
140
+ freqs_cis: torch.Tensor,
141
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
142
+ freqs_cis = freqs_cis[None, :, None, :]
143
+ # xq : B N H Hc
144
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
145
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
146
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
147
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
148
+ return xq_out.type_as(xq), xk_out.type_as(xk)
149
+
150
+
151
+ class RAttention(nn.Module):
152
+ def __init__(
153
+ self,
154
+ dim: int,
155
+ num_heads: int = 8,
156
+ qkv_bias: bool = False,
157
+ qk_norm: bool = True,
158
+ attn_drop: float = 0.,
159
+ proj_drop: float = 0.,
160
+ norm_layer: nn.Module = RMSNorm,
161
+ ) -> None:
162
+ super().__init__()
163
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
164
+
165
+ self.dim = dim
166
+ self.num_heads = num_heads
167
+ self.head_dim = dim // num_heads
168
+ self.scale = self.head_dim ** -0.5
169
+
170
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
171
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
172
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+
177
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
178
+ B, N, C = x.shape
179
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
180
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
181
+ q = self.q_norm(q)
182
+ k = self.k_norm(k)
183
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
184
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
185
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
186
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
187
+
188
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
189
+
190
+ x = x.transpose(1, 2).reshape(B, N, C)
191
+ x = self.proj(x)
192
+ x = self.proj_drop(x)
193
+ return x
194
+
195
+
196
+
197
+ class DiTBlock(nn.Module):
198
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
199
+ super().__init__()
200
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
201
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
202
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
203
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
204
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
205
+ self.adaLN_modulation = nn.Sequential(
206
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
207
+ )
208
+
209
+ def forward(self, x, c, pos, mask=None):
210
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
211
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
212
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
213
+ return x
214
+
215
+
216
+ class DiT(nn.Module):
217
+ def __init__(
218
+ self,
219
+ in_channels=4,
220
+ num_groups=12,
221
+ hidden_size=1152,
222
+ num_blocks=18,
223
+ patch_size=2,
224
+ num_classes=1000,
225
+ learn_sigma=True,
226
+ deep_supervision=0,
227
+ weight_path=None,
228
+ load_ema=False,
229
+ ):
230
+ super().__init__()
231
+ self.deep_supervision = deep_supervision
232
+ self.learn_sigma = learn_sigma
233
+ self.in_channels = in_channels
234
+ self.out_channels = in_channels
235
+ self.hidden_size = hidden_size
236
+ self.num_groups = num_groups
237
+ self.num_blocks = num_blocks
238
+ self.patch_size = patch_size
239
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
240
+ self.t_embedder = TimestepEmbedder(hidden_size)
241
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
242
+
243
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
244
+
245
+ self.weight_path = weight_path
246
+
247
+ self.load_ema = load_ema
248
+ self.blocks = nn.ModuleList([
249
+ DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
250
+ ])
251
+ self.initialize_weights()
252
+ self.precompute_pos = dict()
253
+
254
+ def fetch_pos(self, height, width, device, dtype):
255
+ if (height, width) in self.precompute_pos:
256
+ return self.precompute_pos[(height, width)].to(device, dtype)
257
+ else:
258
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
259
+ self.precompute_pos[(height, width)] = pos
260
+ return pos
261
+
262
+ def initialize_weights(self):
263
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
264
+ w = self.x_embedder.proj.weight.data
265
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
266
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
267
+
268
+ # Initialize label embedding table:
269
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
270
+
271
+ # Initialize timestep embedding MLP:
272
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
273
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
274
+
275
+ # Zero-out output layers:
276
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
277
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
278
+ nn.init.constant_(self.final_layer.linear.weight, 0)
279
+ nn.init.constant_(self.final_layer.linear.bias, 0)
280
+
281
+ def forward(self, x, t, y, masks=None):
282
+ if masks is None:
283
+ masks = [None, ]*self.num_blocks
284
+ if isinstance(masks, torch.Tensor):
285
+ masks = masks.unbind(0)
286
+ if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
287
+ masks = masks + [None]*(self.num_blocks-len(masks))
288
+
289
+ B, _, H, W = x.shape
290
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
291
+ x = self.x_embedder(x)
292
+ pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
293
+ B, L, C = x.shape
294
+ t = self.t_embedder(t.view(-1)).view(B, -1, C)
295
+ y = self.y_embedder(y).view(B, 1, C)
296
+ condition = nn.functional.silu(t + y)
297
+ for i, block in enumerate(self.blocks):
298
+ x = block(x, condition, pos, masks[i])
299
+ x = self.final_layer(x, condition)
300
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
301
+ return x
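Both denoisers rely on the same 2D rotary position embedding helpers defined above; a quick shape check, with an assumed 16x16 patch grid and 64-dimensional heads:

```python
import torch

head_dim, h, w = 64, 16, 16
freqs = precompute_freqs_cis_2d(head_dim, h, w)      # (256, 32) complex: head_dim/2 entries per token
q = torch.randn(2, h * w, 8, head_dim)               # B, N, heads, head_dim
k = torch.randn(2, h * w, 8, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```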
src/models/encoder.py ADDED
@@ -0,0 +1,132 @@
1
+ import torch
2
+ import copy
3
+ import os
4
+ import timm
5
+ import transformers
6
+ import torch.nn as nn
7
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8
+ from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
9
+ from torchvision.transforms import Normalize
10
+
11
+ class RandViT(nn.Module):
12
+ def __init__(self, model_id, weight_path:str=None):
13
+ super(RandViT, self).__init__()
14
+ self.encoder = timm.create_model(
15
+ model_id,
16
+ num_classes=0,
17
+ )
18
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
19
+ self.encoder.head = torch.nn.Identity()
20
+ self.patch_size = self.encoder.patch_embed.patch_size
21
+ self.shifts = nn.Parameter(torch.tensor([0.0
22
+ ]), requires_grad=False)
23
+ self.scales = nn.Parameter(torch.tensor([1.0
24
+ ]), requires_grad=False)
25
+
26
+ def forward(self, x):
27
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
28
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
29
+ b, c, h, w = x.shape
30
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
31
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
32
+ feature = feature.transpose(1, 2)
33
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
34
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
35
+ return feature
36
+
37
+ class MAE(nn.Module):
38
+ def __init__(self, model_id, weight_path:str):
39
+ super(MAE, self).__init__()
40
+ if os.path.isdir(weight_path):
41
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
42
+ self.encoder = timm.create_model(
43
+ model_id,
44
+ checkpoint_path=weight_path,
45
+ num_classes=0,
46
+ )
47
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
48
+ self.encoder.head = torch.nn.Identity()
49
+ self.patch_size = self.encoder.patch_embed.patch_size
50
+ self.shifts = nn.Parameter(torch.tensor([0.0
51
+ ]), requires_grad=False)
52
+ self.scales = nn.Parameter(torch.tensor([1.0
53
+ ]), requires_grad=False)
54
+
55
+ def forward(self, x):
56
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
57
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
58
+ b, c, h, w = x.shape
59
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
60
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
61
+ feature = feature.transpose(1, 2)
62
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
63
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
64
+ return feature
65
+
66
+ class DINO(nn.Module):
67
+ def __init__(self, model_id, weight_path:str):
68
+ super(DINO, self).__init__()
69
+ if os.path.isdir(weight_path):
70
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
71
+ self.encoder = timm.create_model(
72
+ model_id,
73
+ checkpoint_path=weight_path,
74
+ num_classes=0,
75
+ )
76
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
77
+ self.encoder.head = torch.nn.Identity()
78
+ self.patch_size = self.encoder.patch_embed.patch_size
79
+ self.shifts = nn.Parameter(torch.tensor([ 0.0,
80
+ ]), requires_grad=False)
81
+ self.scales = nn.Parameter(torch.tensor([ 1.0,
82
+ ]), requires_grad=False)
83
+
84
+ def forward(self, x):
85
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
86
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
87
+ b, c, h, w = x.shape
88
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
89
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
90
+ feature = feature.transpose(1, 2)
91
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
92
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
93
+ return feature
94
+
95
+ class CLIP(nn.Module):
96
+ def __init__(self, model_id, weight_path:str):
97
+ super(CLIP, self).__init__()
98
+ self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
99
+ self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
100
+ self.shifts = nn.Parameter(torch.tensor([0.0,
101
+ ]), requires_grad=False)
102
+ self.scales = nn.Parameter(torch.tensor([1.0,
103
+ ]), requires_grad=False)
104
+
105
+ def forward(self, x):
106
+ x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
107
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
108
+ b, c, h, w = x.shape
109
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
110
+ feature = self.encoder(x)['last_hidden_state'][:, 1:]
111
+ feature = feature.transpose(1, 2)
112
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
113
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
114
+ return feature
115
+
116
+
117
+
118
+ class DINOv2(nn.Module):
119
+ def __init__(self, model_id, weight_path:str):
120
+ super(DINOv2, self).__init__()
121
+ self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
122
+ self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size
123
+
124
+ def forward(self, x):
125
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
126
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
127
+ b, c, h, w = x.shape
128
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
129
+ feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
130
+ feature = feature.transpose(1, 2)
131
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
132
+ return feature
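All of these wrappers follow the same recipe: normalize, resize to 224, drop the prefix tokens, and reshape the patch tokens into a `(B, C, 224/patch, 224/patch)` feature map that a REPA-style alignment loss can consume. A sketch with an assumed Hugging Face checkpoint id:

```python
import torch

enc = DINOv2(model_id="dinov2-base", weight_path="facebook/dinov2-base")  # id/path are assumptions
imgs = torch.rand(2, 3, 256, 256)        # images in [0, 1]; resized to 224 internally
feats = enc(imgs)                        # e.g. (2, 768, 16, 16) for a 14-pixel patch size
```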
src/models/vae.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import subprocess
3
+ import lightning.pytorch as pl
4
+
5
+ import logging
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+ def class_fn_from_str(class_str):
10
+ class_module, from_class = class_str.rsplit(".", 1)
11
+ class_module = __import__(class_module, fromlist=[from_class])
12
+ return getattr(class_module, from_class)
13
+
14
+
15
+ class BaseVAE(torch.nn.Module):
16
+ def __init__(self, scale=1.0, shift=0.0):
17
+ super().__init__()
18
+ self.model = torch.nn.Identity()
19
+ self.scale = scale
20
+ self.shift = shift
21
+
22
+ def encode(self, x):
23
+ return x/self.scale+self.shift
24
+
25
+ def decode(self, x):
26
+ return (x-self.shift)*self.scale
27
+
28
+
29
+ # note: nearest-neighbour interpolation caused severe artifacts here, so bicubic is used for encode and decode
30
+ class DownSampleVAE(BaseVAE):
31
+ def __init__(self, down_ratio, scale=1.0, shift=0.0):
32
+ super().__init__()
33
+ self.model = torch.nn.Identity()
34
+ self.scale = scale
35
+ self.shift = shift
36
+ self.down_ratio = down_ratio
37
+
38
+ def encode(self, x):
39
+ x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False)
40
+ return x/self.scale+self.shift
41
+
42
+ def decode(self, x):
43
+ x = (x-self.shift)*self.scale
44
+ x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False)
45
+ return x
46
+
47
+
48
+
49
+ class LatentVAE(BaseVAE):
50
+ def __init__(self, precompute=False, weight_path:str=None):
51
+ super().__init__()
52
+ self.precompute = precompute
53
+ self.model = None
54
+ self.weight_path = weight_path
55
+
56
+ from diffusers.models import AutoencoderKL
57
+ self.model = AutoencoderKL.from_pretrained(self.weight_path)
58
+ self.scaling_factor = self.model.config.scaling_factor
59
+
60
+ @torch.no_grad()
61
+ def encode(self, x):
62
+ assert self.model is not None
63
+ if self.precompute:
64
+ return x.mul_(self.scaling_factor)
65
+ return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)
66
+
67
+ @torch.no_grad()
68
+ def decode(self, x):
69
+ assert self.model is not None
70
+ return self.model.decode(x.div_(self.scaling_factor)).sample
71
+
72
+
73
+ def uint82fp(x):
74
+ x = x.to(torch.float32)
75
+ x = (x - 127.5) / 127.5
76
+ return x
77
+
78
+ def fp2uint8(x):
79
+ x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
80
+ return x
81
+
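A closing sketch of the decode path used in `predict_step`: sampled latents are rescaled by the VAE's `scaling_factor`, decoded, and converted to uint8 with `fp2uint8`. The checkpoint id below is an assumption for illustration; the configs point to a local `sd-vae-ft-ema` copy:

```python
import torch

vae = LatentVAE(precompute=False, weight_path="stabilityai/sd-vae-ft-ema")  # assumed id
latents = torch.randn(2, 4, 32, 32)      # e.g. sampler output for 256x256 images
images = vae.decode(latents)             # float images, roughly in [-1, 1]
images_u8 = fp2uint8(images)             # uint8 images in [0, 255]
```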