RiverZ committed
Commit 3b609b9 · 1 Parent(s): 0690a50
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. ICEdit +0 -1
  2. app.py +1 -1
  3. icedit/diffusers/__init__.py +1014 -0
  4. icedit/diffusers/callbacks.py +209 -0
  5. icedit/diffusers/commands/__init__.py +27 -0
  6. icedit/diffusers/commands/diffusers_cli.py +43 -0
  7. icedit/diffusers/commands/env.py +180 -0
  8. icedit/diffusers/commands/fp16_safetensors.py +132 -0
  9. icedit/diffusers/configuration_utils.py +732 -0
  10. icedit/diffusers/dependency_versions_check.py +34 -0
  11. icedit/diffusers/dependency_versions_table.py +46 -0
  12. icedit/diffusers/experimental/__init__.py +1 -0
  13. icedit/diffusers/experimental/rl/__init__.py +1 -0
  14. icedit/diffusers/experimental/rl/value_guided_sampling.py +153 -0
  15. icedit/diffusers/image_processor.py +1314 -0
  16. icedit/diffusers/loaders/__init__.py +121 -0
  17. icedit/diffusers/loaders/ip_adapter.py +871 -0
  18. icedit/diffusers/loaders/lora_base.py +900 -0
  19. icedit/diffusers/loaders/lora_conversion_utils.py +1150 -0
  20. icedit/diffusers/loaders/lora_pipeline.py +0 -0
  21. icedit/diffusers/loaders/peft.py +750 -0
  22. icedit/diffusers/loaders/single_file.py +550 -0
  23. icedit/diffusers/loaders/single_file_model.py +385 -0
  24. icedit/diffusers/loaders/single_file_utils.py +0 -0
  25. icedit/diffusers/loaders/textual_inversion.py +580 -0
  26. icedit/diffusers/loaders/transformer_flux.py +181 -0
  27. icedit/diffusers/loaders/transformer_sd3.py +89 -0
  28. icedit/diffusers/loaders/unet.py +927 -0
  29. icedit/diffusers/loaders/unet_loader_utils.py +163 -0
  30. icedit/diffusers/loaders/utils.py +59 -0
  31. icedit/diffusers/models/__init__.py +172 -0
  32. icedit/diffusers/models/activations.py +178 -0
  33. icedit/diffusers/models/adapter.py +584 -0
  34. icedit/diffusers/models/attention.py +1252 -0
  35. icedit/diffusers/models/attention_flax.py +494 -0
  36. icedit/diffusers/models/attention_processor.py +0 -0
  37. icedit/diffusers/models/autoencoders/__init__.py +13 -0
  38. icedit/diffusers/models/autoencoders/autoencoder_asym_kl.py +184 -0
  39. icedit/diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  40. icedit/diffusers/models/autoencoders/autoencoder_kl.py +571 -0
  41. icedit/diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  42. icedit/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  43. icedit/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  44. icedit/diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  45. icedit/diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  46. icedit/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +394 -0
  47. icedit/diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  48. icedit/diffusers/models/autoencoders/autoencoder_tiny.py +350 -0
  49. icedit/diffusers/models/autoencoders/consistency_decoder_vae.py +460 -0
  50. icedit/diffusers/models/autoencoders/vae.py +995 -0
ICEdit DELETED
@@ -1 +0,0 @@
-Subproject commit 6e4f95590e5b56ca1313dc7f515a4d6bed49244c
app.py CHANGED
@@ -4,7 +4,7 @@ python scripts/gradio_demo.py
 
 import sys
 import os
-workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "ICEdit/icedit"))
+workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "icedit"))
 
 if workspace_dir not in sys.path:
     sys.path.insert(0, workspace_dir)
icedit/diffusers/__init__.py ADDED
@@ -0,0 +1,1014 @@
1
+ __version__ = "0.32.2"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_flax_available,
10
+ is_k_diffusion_available,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_sentencepiece_available,
16
+ is_torch_available,
17
+ is_torchsde_available,
18
+ is_transformers_available,
19
+ )
20
+
21
+
22
+ # Lazy Import based on
23
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
24
+
25
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
26
+ # and is used to defer the actual importing for when the objects are requested.
27
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
28
+
29
+ _import_structure = {
30
+ "configuration_utils": ["ConfigMixin"],
31
+ "loaders": ["FromOriginalModelMixin"],
32
+ "models": [],
33
+ "pipelines": [],
34
+ "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
35
+ "schedulers": [],
36
+ "utils": [
37
+ "OptionalDependencyNotAvailable",
38
+ "is_flax_available",
39
+ "is_inflect_available",
40
+ "is_invisible_watermark_available",
41
+ "is_k_diffusion_available",
42
+ "is_k_diffusion_version",
43
+ "is_librosa_available",
44
+ "is_note_seq_available",
45
+ "is_onnx_available",
46
+ "is_scipy_available",
47
+ "is_torch_available",
48
+ "is_torchsde_available",
49
+ "is_transformers_available",
50
+ "is_transformers_version",
51
+ "is_unidecode_available",
52
+ "logging",
53
+ ],
54
+ }
55
+
56
+ try:
57
+ if not is_onnx_available():
58
+ raise OptionalDependencyNotAvailable()
59
+ except OptionalDependencyNotAvailable:
60
+ from .utils import dummy_onnx_objects # noqa F403
61
+
62
+ _import_structure["utils.dummy_onnx_objects"] = [
63
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
64
+ ]
65
+
66
+ else:
67
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
68
+
69
+ try:
70
+ if not is_torch_available():
71
+ raise OptionalDependencyNotAvailable()
72
+ except OptionalDependencyNotAvailable:
73
+ from .utils import dummy_pt_objects # noqa F403
74
+
75
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
76
+
77
+ else:
78
+ _import_structure["models"].extend(
79
+ [
80
+ "AllegroTransformer3DModel",
81
+ "AsymmetricAutoencoderKL",
82
+ "AuraFlowTransformer2DModel",
83
+ "AutoencoderDC",
84
+ "AutoencoderKL",
85
+ "AutoencoderKLAllegro",
86
+ "AutoencoderKLCogVideoX",
87
+ "AutoencoderKLHunyuanVideo",
88
+ "AutoencoderKLLTXVideo",
89
+ "AutoencoderKLMochi",
90
+ "AutoencoderKLTemporalDecoder",
91
+ "AutoencoderOobleck",
92
+ "AutoencoderTiny",
93
+ "CogVideoXTransformer3DModel",
94
+ "CogView3PlusTransformer2DModel",
95
+ "ConsistencyDecoderVAE",
96
+ "ControlNetModel",
97
+ "ControlNetUnionModel",
98
+ "ControlNetXSAdapter",
99
+ "DiTTransformer2DModel",
100
+ "FluxControlNetModel",
101
+ "FluxMultiControlNetModel",
102
+ "FluxTransformer2DModel",
103
+ "HunyuanDiT2DControlNetModel",
104
+ "HunyuanDiT2DModel",
105
+ "HunyuanDiT2DMultiControlNetModel",
106
+ "HunyuanVideoTransformer3DModel",
107
+ "I2VGenXLUNet",
108
+ "Kandinsky3UNet",
109
+ "LatteTransformer3DModel",
110
+ "LTXVideoTransformer3DModel",
111
+ "LuminaNextDiT2DModel",
112
+ "MochiTransformer3DModel",
113
+ "ModelMixin",
114
+ "MotionAdapter",
115
+ "MultiAdapter",
116
+ "MultiControlNetModel",
117
+ "PixArtTransformer2DModel",
118
+ "PriorTransformer",
119
+ "SanaTransformer2DModel",
120
+ "SD3ControlNetModel",
121
+ "SD3MultiControlNetModel",
122
+ "SD3Transformer2DModel",
123
+ "SparseControlNetModel",
124
+ "StableAudioDiTModel",
125
+ "StableCascadeUNet",
126
+ "T2IAdapter",
127
+ "T5FilmDecoder",
128
+ "Transformer2DModel",
129
+ "UNet1DModel",
130
+ "UNet2DConditionModel",
131
+ "UNet2DModel",
132
+ "UNet3DConditionModel",
133
+ "UNetControlNetXSModel",
134
+ "UNetMotionModel",
135
+ "UNetSpatioTemporalConditionModel",
136
+ "UVit2DModel",
137
+ "VQModel",
138
+ ]
139
+ )
140
+ _import_structure["optimization"] = [
141
+ "get_constant_schedule",
142
+ "get_constant_schedule_with_warmup",
143
+ "get_cosine_schedule_with_warmup",
144
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
145
+ "get_linear_schedule_with_warmup",
146
+ "get_polynomial_decay_schedule_with_warmup",
147
+ "get_scheduler",
148
+ ]
149
+ _import_structure["pipelines"].extend(
150
+ [
151
+ "AudioPipelineOutput",
152
+ "AutoPipelineForImage2Image",
153
+ "AutoPipelineForInpainting",
154
+ "AutoPipelineForText2Image",
155
+ "ConsistencyModelPipeline",
156
+ "DanceDiffusionPipeline",
157
+ "DDIMPipeline",
158
+ "DDPMPipeline",
159
+ "DiffusionPipeline",
160
+ "DiTPipeline",
161
+ "ImagePipelineOutput",
162
+ "KarrasVePipeline",
163
+ "LDMPipeline",
164
+ "LDMSuperResolutionPipeline",
165
+ "PNDMPipeline",
166
+ "RePaintPipeline",
167
+ "ScoreSdeVePipeline",
168
+ "StableDiffusionMixin",
169
+ ]
170
+ )
171
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
172
+ _import_structure["schedulers"].extend(
173
+ [
174
+ "AmusedScheduler",
175
+ "CMStochasticIterativeScheduler",
176
+ "CogVideoXDDIMScheduler",
177
+ "CogVideoXDPMScheduler",
178
+ "DDIMInverseScheduler",
179
+ "DDIMParallelScheduler",
180
+ "DDIMScheduler",
181
+ "DDPMParallelScheduler",
182
+ "DDPMScheduler",
183
+ "DDPMWuerstchenScheduler",
184
+ "DEISMultistepScheduler",
185
+ "DPMSolverMultistepInverseScheduler",
186
+ "DPMSolverMultistepScheduler",
187
+ "DPMSolverSinglestepScheduler",
188
+ "EDMDPMSolverMultistepScheduler",
189
+ "EDMEulerScheduler",
190
+ "EulerAncestralDiscreteScheduler",
191
+ "EulerDiscreteScheduler",
192
+ "FlowMatchEulerDiscreteScheduler",
193
+ "FlowMatchHeunDiscreteScheduler",
194
+ "HeunDiscreteScheduler",
195
+ "IPNDMScheduler",
196
+ "KarrasVeScheduler",
197
+ "KDPM2AncestralDiscreteScheduler",
198
+ "KDPM2DiscreteScheduler",
199
+ "LCMScheduler",
200
+ "PNDMScheduler",
201
+ "RePaintScheduler",
202
+ "SASolverScheduler",
203
+ "SchedulerMixin",
204
+ "ScoreSdeVeScheduler",
205
+ "TCDScheduler",
206
+ "UnCLIPScheduler",
207
+ "UniPCMultistepScheduler",
208
+ "VQDiffusionScheduler",
209
+ ]
210
+ )
211
+ _import_structure["training_utils"] = ["EMAModel"]
212
+
213
+ try:
214
+ if not (is_torch_available() and is_scipy_available()):
215
+ raise OptionalDependencyNotAvailable()
216
+ except OptionalDependencyNotAvailable:
217
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
218
+
219
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
220
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
221
+ ]
222
+
223
+ else:
224
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
225
+
226
+ try:
227
+ if not (is_torch_available() and is_torchsde_available()):
228
+ raise OptionalDependencyNotAvailable()
229
+ except OptionalDependencyNotAvailable:
230
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
231
+
232
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
233
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
234
+ ]
235
+
236
+ else:
237
+ _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
238
+
239
+ try:
240
+ if not (is_torch_available() and is_transformers_available()):
241
+ raise OptionalDependencyNotAvailable()
242
+ except OptionalDependencyNotAvailable:
243
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
244
+
245
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
246
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
247
+ ]
248
+
249
+ else:
250
+ _import_structure["pipelines"].extend(
251
+ [
252
+ "AllegroPipeline",
253
+ "AltDiffusionImg2ImgPipeline",
254
+ "AltDiffusionPipeline",
255
+ "AmusedImg2ImgPipeline",
256
+ "AmusedInpaintPipeline",
257
+ "AmusedPipeline",
258
+ "AnimateDiffControlNetPipeline",
259
+ "AnimateDiffPAGPipeline",
260
+ "AnimateDiffPipeline",
261
+ "AnimateDiffSDXLPipeline",
262
+ "AnimateDiffSparseControlNetPipeline",
263
+ "AnimateDiffVideoToVideoControlNetPipeline",
264
+ "AnimateDiffVideoToVideoPipeline",
265
+ "AudioLDM2Pipeline",
266
+ "AudioLDM2ProjectionModel",
267
+ "AudioLDM2UNet2DConditionModel",
268
+ "AudioLDMPipeline",
269
+ "AuraFlowPipeline",
270
+ "BlipDiffusionControlNetPipeline",
271
+ "BlipDiffusionPipeline",
272
+ "CLIPImageProjection",
273
+ "CogVideoXFunControlPipeline",
274
+ "CogVideoXImageToVideoPipeline",
275
+ "CogVideoXPipeline",
276
+ "CogVideoXVideoToVideoPipeline",
277
+ "CogView3PlusPipeline",
278
+ "CycleDiffusionPipeline",
279
+ "FluxControlImg2ImgPipeline",
280
+ "FluxControlInpaintPipeline",
281
+ "FluxControlNetImg2ImgPipeline",
282
+ "FluxControlNetInpaintPipeline",
283
+ "FluxControlNetPipeline",
284
+ "FluxControlPipeline",
285
+ "FluxFillPipeline",
286
+ "FluxImg2ImgPipeline",
287
+ "FluxInpaintPipeline",
288
+ "FluxPipeline",
289
+ "FluxPriorReduxPipeline",
290
+ "HunyuanDiTControlNetPipeline",
291
+ "HunyuanDiTPAGPipeline",
292
+ "HunyuanDiTPipeline",
293
+ "HunyuanVideoPipeline",
294
+ "I2VGenXLPipeline",
295
+ "IFImg2ImgPipeline",
296
+ "IFImg2ImgSuperResolutionPipeline",
297
+ "IFInpaintingPipeline",
298
+ "IFInpaintingSuperResolutionPipeline",
299
+ "IFPipeline",
300
+ "IFSuperResolutionPipeline",
301
+ "ImageTextPipelineOutput",
302
+ "Kandinsky3Img2ImgPipeline",
303
+ "Kandinsky3Pipeline",
304
+ "KandinskyCombinedPipeline",
305
+ "KandinskyImg2ImgCombinedPipeline",
306
+ "KandinskyImg2ImgPipeline",
307
+ "KandinskyInpaintCombinedPipeline",
308
+ "KandinskyInpaintPipeline",
309
+ "KandinskyPipeline",
310
+ "KandinskyPriorPipeline",
311
+ "KandinskyV22CombinedPipeline",
312
+ "KandinskyV22ControlnetImg2ImgPipeline",
313
+ "KandinskyV22ControlnetPipeline",
314
+ "KandinskyV22Img2ImgCombinedPipeline",
315
+ "KandinskyV22Img2ImgPipeline",
316
+ "KandinskyV22InpaintCombinedPipeline",
317
+ "KandinskyV22InpaintPipeline",
318
+ "KandinskyV22Pipeline",
319
+ "KandinskyV22PriorEmb2EmbPipeline",
320
+ "KandinskyV22PriorPipeline",
321
+ "LatentConsistencyModelImg2ImgPipeline",
322
+ "LatentConsistencyModelPipeline",
323
+ "LattePipeline",
324
+ "LDMTextToImagePipeline",
325
+ "LEditsPPPipelineStableDiffusion",
326
+ "LEditsPPPipelineStableDiffusionXL",
327
+ "LTXImageToVideoPipeline",
328
+ "LTXPipeline",
329
+ "LuminaText2ImgPipeline",
330
+ "MarigoldDepthPipeline",
331
+ "MarigoldNormalsPipeline",
332
+ "MochiPipeline",
333
+ "MusicLDMPipeline",
334
+ "PaintByExamplePipeline",
335
+ "PIAPipeline",
336
+ "PixArtAlphaPipeline",
337
+ "PixArtSigmaPAGPipeline",
338
+ "PixArtSigmaPipeline",
339
+ "ReduxImageEncoder",
340
+ "SanaPAGPipeline",
341
+ "SanaPipeline",
342
+ "SemanticStableDiffusionPipeline",
343
+ "ShapEImg2ImgPipeline",
344
+ "ShapEPipeline",
345
+ "StableAudioPipeline",
346
+ "StableAudioProjectionModel",
347
+ "StableCascadeCombinedPipeline",
348
+ "StableCascadeDecoderPipeline",
349
+ "StableCascadePriorPipeline",
350
+ "StableDiffusion3ControlNetInpaintingPipeline",
351
+ "StableDiffusion3ControlNetPipeline",
352
+ "StableDiffusion3Img2ImgPipeline",
353
+ "StableDiffusion3InpaintPipeline",
354
+ "StableDiffusion3PAGImg2ImgPipeline",
355
+ "StableDiffusion3PAGImg2ImgPipeline",
356
+ "StableDiffusion3PAGPipeline",
357
+ "StableDiffusion3Pipeline",
358
+ "StableDiffusionAdapterPipeline",
359
+ "StableDiffusionAttendAndExcitePipeline",
360
+ "StableDiffusionControlNetImg2ImgPipeline",
361
+ "StableDiffusionControlNetInpaintPipeline",
362
+ "StableDiffusionControlNetPAGInpaintPipeline",
363
+ "StableDiffusionControlNetPAGPipeline",
364
+ "StableDiffusionControlNetPipeline",
365
+ "StableDiffusionControlNetXSPipeline",
366
+ "StableDiffusionDepth2ImgPipeline",
367
+ "StableDiffusionDiffEditPipeline",
368
+ "StableDiffusionGLIGENPipeline",
369
+ "StableDiffusionGLIGENTextImagePipeline",
370
+ "StableDiffusionImageVariationPipeline",
371
+ "StableDiffusionImg2ImgPipeline",
372
+ "StableDiffusionInpaintPipeline",
373
+ "StableDiffusionInpaintPipelineLegacy",
374
+ "StableDiffusionInstructPix2PixPipeline",
375
+ "StableDiffusionLatentUpscalePipeline",
376
+ "StableDiffusionLDM3DPipeline",
377
+ "StableDiffusionModelEditingPipeline",
378
+ "StableDiffusionPAGImg2ImgPipeline",
379
+ "StableDiffusionPAGInpaintPipeline",
380
+ "StableDiffusionPAGPipeline",
381
+ "StableDiffusionPanoramaPipeline",
382
+ "StableDiffusionParadigmsPipeline",
383
+ "StableDiffusionPipeline",
384
+ "StableDiffusionPipelineSafe",
385
+ "StableDiffusionPix2PixZeroPipeline",
386
+ "StableDiffusionSAGPipeline",
387
+ "StableDiffusionUpscalePipeline",
388
+ "StableDiffusionXLAdapterPipeline",
389
+ "StableDiffusionXLControlNetImg2ImgPipeline",
390
+ "StableDiffusionXLControlNetInpaintPipeline",
391
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
392
+ "StableDiffusionXLControlNetPAGPipeline",
393
+ "StableDiffusionXLControlNetPipeline",
394
+ "StableDiffusionXLControlNetUnionImg2ImgPipeline",
395
+ "StableDiffusionXLControlNetUnionInpaintPipeline",
396
+ "StableDiffusionXLControlNetUnionPipeline",
397
+ "StableDiffusionXLControlNetXSPipeline",
398
+ "StableDiffusionXLImg2ImgPipeline",
399
+ "StableDiffusionXLInpaintPipeline",
400
+ "StableDiffusionXLInstructPix2PixPipeline",
401
+ "StableDiffusionXLPAGImg2ImgPipeline",
402
+ "StableDiffusionXLPAGInpaintPipeline",
403
+ "StableDiffusionXLPAGPipeline",
404
+ "StableDiffusionXLPipeline",
405
+ "StableUnCLIPImg2ImgPipeline",
406
+ "StableUnCLIPPipeline",
407
+ "StableVideoDiffusionPipeline",
408
+ "TextToVideoSDPipeline",
409
+ "TextToVideoZeroPipeline",
410
+ "TextToVideoZeroSDXLPipeline",
411
+ "UnCLIPImageVariationPipeline",
412
+ "UnCLIPPipeline",
413
+ "UniDiffuserModel",
414
+ "UniDiffuserPipeline",
415
+ "UniDiffuserTextDecoder",
416
+ "VersatileDiffusionDualGuidedPipeline",
417
+ "VersatileDiffusionImageVariationPipeline",
418
+ "VersatileDiffusionPipeline",
419
+ "VersatileDiffusionTextToImagePipeline",
420
+ "VideoToVideoSDPipeline",
421
+ "VQDiffusionPipeline",
422
+ "WuerstchenCombinedPipeline",
423
+ "WuerstchenDecoderPipeline",
424
+ "WuerstchenPriorPipeline",
425
+ ]
426
+ )
427
+
428
+ try:
429
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
430
+ raise OptionalDependencyNotAvailable()
431
+ except OptionalDependencyNotAvailable:
432
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
433
+
434
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
435
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
436
+ ]
437
+
438
+ else:
439
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
440
+
441
+ try:
442
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
443
+ raise OptionalDependencyNotAvailable()
444
+ except OptionalDependencyNotAvailable:
445
+ from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
446
+
447
+ _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
448
+ name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
449
+ ]
450
+
451
+ else:
452
+ _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
453
+
454
+ try:
455
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
456
+ raise OptionalDependencyNotAvailable()
457
+ except OptionalDependencyNotAvailable:
458
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
459
+
460
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
461
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
462
+ ]
463
+
464
+ else:
465
+ _import_structure["pipelines"].extend(
466
+ [
467
+ "OnnxStableDiffusionImg2ImgPipeline",
468
+ "OnnxStableDiffusionInpaintPipeline",
469
+ "OnnxStableDiffusionInpaintPipelineLegacy",
470
+ "OnnxStableDiffusionPipeline",
471
+ "OnnxStableDiffusionUpscalePipeline",
472
+ "StableDiffusionOnnxPipeline",
473
+ ]
474
+ )
475
+
476
+ try:
477
+ if not (is_torch_available() and is_librosa_available()):
478
+ raise OptionalDependencyNotAvailable()
479
+ except OptionalDependencyNotAvailable:
480
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
481
+
482
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
483
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
484
+ ]
485
+
486
+ else:
487
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
488
+
489
+ try:
490
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
491
+ raise OptionalDependencyNotAvailable()
492
+ except OptionalDependencyNotAvailable:
493
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
494
+
495
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
496
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
497
+ ]
498
+
499
+
500
+ else:
501
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
502
+
503
+ try:
504
+ if not is_flax_available():
505
+ raise OptionalDependencyNotAvailable()
506
+ except OptionalDependencyNotAvailable:
507
+ from .utils import dummy_flax_objects # noqa F403
508
+
509
+ _import_structure["utils.dummy_flax_objects"] = [
510
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
511
+ ]
512
+
513
+
514
+ else:
515
+ _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
516
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
517
+ _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
518
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
519
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
520
+ _import_structure["schedulers"].extend(
521
+ [
522
+ "FlaxDDIMScheduler",
523
+ "FlaxDDPMScheduler",
524
+ "FlaxDPMSolverMultistepScheduler",
525
+ "FlaxEulerDiscreteScheduler",
526
+ "FlaxKarrasVeScheduler",
527
+ "FlaxLMSDiscreteScheduler",
528
+ "FlaxPNDMScheduler",
529
+ "FlaxSchedulerMixin",
530
+ "FlaxScoreSdeVeScheduler",
531
+ ]
532
+ )
533
+
534
+
535
+ try:
536
+ if not (is_flax_available() and is_transformers_available()):
537
+ raise OptionalDependencyNotAvailable()
538
+ except OptionalDependencyNotAvailable:
539
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
540
+
541
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
542
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
543
+ ]
544
+
545
+
546
+ else:
547
+ _import_structure["pipelines"].extend(
548
+ [
549
+ "FlaxStableDiffusionControlNetPipeline",
550
+ "FlaxStableDiffusionImg2ImgPipeline",
551
+ "FlaxStableDiffusionInpaintPipeline",
552
+ "FlaxStableDiffusionPipeline",
553
+ "FlaxStableDiffusionXLPipeline",
554
+ ]
555
+ )
556
+
557
+ try:
558
+ if not (is_note_seq_available()):
559
+ raise OptionalDependencyNotAvailable()
560
+ except OptionalDependencyNotAvailable:
561
+ from .utils import dummy_note_seq_objects # noqa F403
562
+
563
+ _import_structure["utils.dummy_note_seq_objects"] = [
564
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
565
+ ]
566
+
567
+
568
+ else:
569
+ _import_structure["pipelines"].extend(["MidiProcessor"])
570
+
571
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
572
+ from .configuration_utils import ConfigMixin
573
+ from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
574
+
575
+ try:
576
+ if not is_onnx_available():
577
+ raise OptionalDependencyNotAvailable()
578
+ except OptionalDependencyNotAvailable:
579
+ from .utils.dummy_onnx_objects import * # noqa F403
580
+ else:
581
+ from .pipelines import OnnxRuntimeModel
582
+
583
+ try:
584
+ if not is_torch_available():
585
+ raise OptionalDependencyNotAvailable()
586
+ except OptionalDependencyNotAvailable:
587
+ from .utils.dummy_pt_objects import * # noqa F403
588
+ else:
589
+ from .models import (
590
+ AllegroTransformer3DModel,
591
+ AsymmetricAutoencoderKL,
592
+ AuraFlowTransformer2DModel,
593
+ AutoencoderDC,
594
+ AutoencoderKL,
595
+ AutoencoderKLAllegro,
596
+ AutoencoderKLCogVideoX,
597
+ AutoencoderKLHunyuanVideo,
598
+ AutoencoderKLLTXVideo,
599
+ AutoencoderKLMochi,
600
+ AutoencoderKLTemporalDecoder,
601
+ AutoencoderOobleck,
602
+ AutoencoderTiny,
603
+ CogVideoXTransformer3DModel,
604
+ CogView3PlusTransformer2DModel,
605
+ ConsistencyDecoderVAE,
606
+ ControlNetModel,
607
+ ControlNetUnionModel,
608
+ ControlNetXSAdapter,
609
+ DiTTransformer2DModel,
610
+ FluxControlNetModel,
611
+ FluxMultiControlNetModel,
612
+ FluxTransformer2DModel,
613
+ HunyuanDiT2DControlNetModel,
614
+ HunyuanDiT2DModel,
615
+ HunyuanDiT2DMultiControlNetModel,
616
+ HunyuanVideoTransformer3DModel,
617
+ I2VGenXLUNet,
618
+ Kandinsky3UNet,
619
+ LatteTransformer3DModel,
620
+ LTXVideoTransformer3DModel,
621
+ LuminaNextDiT2DModel,
622
+ MochiTransformer3DModel,
623
+ ModelMixin,
624
+ MotionAdapter,
625
+ MultiAdapter,
626
+ MultiControlNetModel,
627
+ PixArtTransformer2DModel,
628
+ PriorTransformer,
629
+ SanaTransformer2DModel,
630
+ SD3ControlNetModel,
631
+ SD3MultiControlNetModel,
632
+ SD3Transformer2DModel,
633
+ SparseControlNetModel,
634
+ StableAudioDiTModel,
635
+ T2IAdapter,
636
+ T5FilmDecoder,
637
+ Transformer2DModel,
638
+ UNet1DModel,
639
+ UNet2DConditionModel,
640
+ UNet2DModel,
641
+ UNet3DConditionModel,
642
+ UNetControlNetXSModel,
643
+ UNetMotionModel,
644
+ UNetSpatioTemporalConditionModel,
645
+ UVit2DModel,
646
+ VQModel,
647
+ )
648
+ from .optimization import (
649
+ get_constant_schedule,
650
+ get_constant_schedule_with_warmup,
651
+ get_cosine_schedule_with_warmup,
652
+ get_cosine_with_hard_restarts_schedule_with_warmup,
653
+ get_linear_schedule_with_warmup,
654
+ get_polynomial_decay_schedule_with_warmup,
655
+ get_scheduler,
656
+ )
657
+ from .pipelines import (
658
+ AudioPipelineOutput,
659
+ AutoPipelineForImage2Image,
660
+ AutoPipelineForInpainting,
661
+ AutoPipelineForText2Image,
662
+ BlipDiffusionControlNetPipeline,
663
+ BlipDiffusionPipeline,
664
+ CLIPImageProjection,
665
+ ConsistencyModelPipeline,
666
+ DanceDiffusionPipeline,
667
+ DDIMPipeline,
668
+ DDPMPipeline,
669
+ DiffusionPipeline,
670
+ DiTPipeline,
671
+ ImagePipelineOutput,
672
+ KarrasVePipeline,
673
+ LDMPipeline,
674
+ LDMSuperResolutionPipeline,
675
+ PNDMPipeline,
676
+ RePaintPipeline,
677
+ ScoreSdeVePipeline,
678
+ StableDiffusionMixin,
679
+ )
680
+ from .quantizers import DiffusersQuantizer
681
+ from .schedulers import (
682
+ AmusedScheduler,
683
+ CMStochasticIterativeScheduler,
684
+ CogVideoXDDIMScheduler,
685
+ CogVideoXDPMScheduler,
686
+ DDIMInverseScheduler,
687
+ DDIMParallelScheduler,
688
+ DDIMScheduler,
689
+ DDPMParallelScheduler,
690
+ DDPMScheduler,
691
+ DDPMWuerstchenScheduler,
692
+ DEISMultistepScheduler,
693
+ DPMSolverMultistepInverseScheduler,
694
+ DPMSolverMultistepScheduler,
695
+ DPMSolverSinglestepScheduler,
696
+ EDMDPMSolverMultistepScheduler,
697
+ EDMEulerScheduler,
698
+ EulerAncestralDiscreteScheduler,
699
+ EulerDiscreteScheduler,
700
+ FlowMatchEulerDiscreteScheduler,
701
+ FlowMatchHeunDiscreteScheduler,
702
+ HeunDiscreteScheduler,
703
+ IPNDMScheduler,
704
+ KarrasVeScheduler,
705
+ KDPM2AncestralDiscreteScheduler,
706
+ KDPM2DiscreteScheduler,
707
+ LCMScheduler,
708
+ PNDMScheduler,
709
+ RePaintScheduler,
710
+ SASolverScheduler,
711
+ SchedulerMixin,
712
+ ScoreSdeVeScheduler,
713
+ TCDScheduler,
714
+ UnCLIPScheduler,
715
+ UniPCMultistepScheduler,
716
+ VQDiffusionScheduler,
717
+ )
718
+ from .training_utils import EMAModel
719
+
720
+ try:
721
+ if not (is_torch_available() and is_scipy_available()):
722
+ raise OptionalDependencyNotAvailable()
723
+ except OptionalDependencyNotAvailable:
724
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
725
+ else:
726
+ from .schedulers import LMSDiscreteScheduler
727
+
728
+ try:
729
+ if not (is_torch_available() and is_torchsde_available()):
730
+ raise OptionalDependencyNotAvailable()
731
+ except OptionalDependencyNotAvailable:
732
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
733
+ else:
734
+ from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
735
+
736
+ try:
737
+ if not (is_torch_available() and is_transformers_available()):
738
+ raise OptionalDependencyNotAvailable()
739
+ except OptionalDependencyNotAvailable:
740
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
741
+ else:
742
+ from .pipelines import (
743
+ AllegroPipeline,
744
+ AltDiffusionImg2ImgPipeline,
745
+ AltDiffusionPipeline,
746
+ AmusedImg2ImgPipeline,
747
+ AmusedInpaintPipeline,
748
+ AmusedPipeline,
749
+ AnimateDiffControlNetPipeline,
750
+ AnimateDiffPAGPipeline,
751
+ AnimateDiffPipeline,
752
+ AnimateDiffSDXLPipeline,
753
+ AnimateDiffSparseControlNetPipeline,
754
+ AnimateDiffVideoToVideoControlNetPipeline,
755
+ AnimateDiffVideoToVideoPipeline,
756
+ AudioLDM2Pipeline,
757
+ AudioLDM2ProjectionModel,
758
+ AudioLDM2UNet2DConditionModel,
759
+ AudioLDMPipeline,
760
+ AuraFlowPipeline,
761
+ CLIPImageProjection,
762
+ CogVideoXFunControlPipeline,
763
+ CogVideoXImageToVideoPipeline,
764
+ CogVideoXPipeline,
765
+ CogVideoXVideoToVideoPipeline,
766
+ CogView3PlusPipeline,
767
+ CycleDiffusionPipeline,
768
+ FluxControlImg2ImgPipeline,
769
+ FluxControlInpaintPipeline,
770
+ FluxControlNetImg2ImgPipeline,
771
+ FluxControlNetInpaintPipeline,
772
+ FluxControlNetPipeline,
773
+ FluxControlPipeline,
774
+ FluxFillPipeline,
775
+ FluxImg2ImgPipeline,
776
+ FluxInpaintPipeline,
777
+ FluxPipeline,
778
+ FluxPriorReduxPipeline,
779
+ HunyuanDiTControlNetPipeline,
780
+ HunyuanDiTPAGPipeline,
781
+ HunyuanDiTPipeline,
782
+ HunyuanVideoPipeline,
783
+ I2VGenXLPipeline,
784
+ IFImg2ImgPipeline,
785
+ IFImg2ImgSuperResolutionPipeline,
786
+ IFInpaintingPipeline,
787
+ IFInpaintingSuperResolutionPipeline,
788
+ IFPipeline,
789
+ IFSuperResolutionPipeline,
790
+ ImageTextPipelineOutput,
791
+ Kandinsky3Img2ImgPipeline,
792
+ Kandinsky3Pipeline,
793
+ KandinskyCombinedPipeline,
794
+ KandinskyImg2ImgCombinedPipeline,
795
+ KandinskyImg2ImgPipeline,
796
+ KandinskyInpaintCombinedPipeline,
797
+ KandinskyInpaintPipeline,
798
+ KandinskyPipeline,
799
+ KandinskyPriorPipeline,
800
+ KandinskyV22CombinedPipeline,
801
+ KandinskyV22ControlnetImg2ImgPipeline,
802
+ KandinskyV22ControlnetPipeline,
803
+ KandinskyV22Img2ImgCombinedPipeline,
804
+ KandinskyV22Img2ImgPipeline,
805
+ KandinskyV22InpaintCombinedPipeline,
806
+ KandinskyV22InpaintPipeline,
807
+ KandinskyV22Pipeline,
808
+ KandinskyV22PriorEmb2EmbPipeline,
809
+ KandinskyV22PriorPipeline,
810
+ LatentConsistencyModelImg2ImgPipeline,
811
+ LatentConsistencyModelPipeline,
812
+ LattePipeline,
813
+ LDMTextToImagePipeline,
814
+ LEditsPPPipelineStableDiffusion,
815
+ LEditsPPPipelineStableDiffusionXL,
816
+ LTXImageToVideoPipeline,
817
+ LTXPipeline,
818
+ LuminaText2ImgPipeline,
819
+ MarigoldDepthPipeline,
820
+ MarigoldNormalsPipeline,
821
+ MochiPipeline,
822
+ MusicLDMPipeline,
823
+ PaintByExamplePipeline,
824
+ PIAPipeline,
825
+ PixArtAlphaPipeline,
826
+ PixArtSigmaPAGPipeline,
827
+ PixArtSigmaPipeline,
828
+ ReduxImageEncoder,
829
+ SanaPAGPipeline,
830
+ SanaPipeline,
831
+ SemanticStableDiffusionPipeline,
832
+ ShapEImg2ImgPipeline,
833
+ ShapEPipeline,
834
+ StableAudioPipeline,
835
+ StableAudioProjectionModel,
836
+ StableCascadeCombinedPipeline,
837
+ StableCascadeDecoderPipeline,
838
+ StableCascadePriorPipeline,
839
+ StableDiffusion3ControlNetPipeline,
840
+ StableDiffusion3Img2ImgPipeline,
841
+ StableDiffusion3InpaintPipeline,
842
+ StableDiffusion3PAGImg2ImgPipeline,
843
+ StableDiffusion3PAGPipeline,
844
+ StableDiffusion3Pipeline,
845
+ StableDiffusionAdapterPipeline,
846
+ StableDiffusionAttendAndExcitePipeline,
847
+ StableDiffusionControlNetImg2ImgPipeline,
848
+ StableDiffusionControlNetInpaintPipeline,
849
+ StableDiffusionControlNetPAGInpaintPipeline,
850
+ StableDiffusionControlNetPAGPipeline,
851
+ StableDiffusionControlNetPipeline,
852
+ StableDiffusionControlNetXSPipeline,
853
+ StableDiffusionDepth2ImgPipeline,
854
+ StableDiffusionDiffEditPipeline,
855
+ StableDiffusionGLIGENPipeline,
856
+ StableDiffusionGLIGENTextImagePipeline,
857
+ StableDiffusionImageVariationPipeline,
858
+ StableDiffusionImg2ImgPipeline,
859
+ StableDiffusionInpaintPipeline,
860
+ StableDiffusionInpaintPipelineLegacy,
861
+ StableDiffusionInstructPix2PixPipeline,
862
+ StableDiffusionLatentUpscalePipeline,
863
+ StableDiffusionLDM3DPipeline,
864
+ StableDiffusionModelEditingPipeline,
865
+ StableDiffusionPAGImg2ImgPipeline,
866
+ StableDiffusionPAGInpaintPipeline,
867
+ StableDiffusionPAGPipeline,
868
+ StableDiffusionPanoramaPipeline,
869
+ StableDiffusionParadigmsPipeline,
870
+ StableDiffusionPipeline,
871
+ StableDiffusionPipelineSafe,
872
+ StableDiffusionPix2PixZeroPipeline,
873
+ StableDiffusionSAGPipeline,
874
+ StableDiffusionUpscalePipeline,
875
+ StableDiffusionXLAdapterPipeline,
876
+ StableDiffusionXLControlNetImg2ImgPipeline,
877
+ StableDiffusionXLControlNetInpaintPipeline,
878
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
879
+ StableDiffusionXLControlNetPAGPipeline,
880
+ StableDiffusionXLControlNetPipeline,
881
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
882
+ StableDiffusionXLControlNetUnionInpaintPipeline,
883
+ StableDiffusionXLControlNetUnionPipeline,
884
+ StableDiffusionXLControlNetXSPipeline,
885
+ StableDiffusionXLImg2ImgPipeline,
886
+ StableDiffusionXLInpaintPipeline,
887
+ StableDiffusionXLInstructPix2PixPipeline,
888
+ StableDiffusionXLPAGImg2ImgPipeline,
889
+ StableDiffusionXLPAGInpaintPipeline,
890
+ StableDiffusionXLPAGPipeline,
891
+ StableDiffusionXLPipeline,
892
+ StableUnCLIPImg2ImgPipeline,
893
+ StableUnCLIPPipeline,
894
+ StableVideoDiffusionPipeline,
895
+ TextToVideoSDPipeline,
896
+ TextToVideoZeroPipeline,
897
+ TextToVideoZeroSDXLPipeline,
898
+ UnCLIPImageVariationPipeline,
899
+ UnCLIPPipeline,
900
+ UniDiffuserModel,
901
+ UniDiffuserPipeline,
902
+ UniDiffuserTextDecoder,
903
+ VersatileDiffusionDualGuidedPipeline,
904
+ VersatileDiffusionImageVariationPipeline,
905
+ VersatileDiffusionPipeline,
906
+ VersatileDiffusionTextToImagePipeline,
907
+ VideoToVideoSDPipeline,
908
+ VQDiffusionPipeline,
909
+ WuerstchenCombinedPipeline,
910
+ WuerstchenDecoderPipeline,
911
+ WuerstchenPriorPipeline,
912
+ )
913
+
914
+ try:
915
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
916
+ raise OptionalDependencyNotAvailable()
917
+ except OptionalDependencyNotAvailable:
918
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
919
+ else:
920
+ from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
921
+
922
+ try:
923
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
924
+ raise OptionalDependencyNotAvailable()
925
+ except OptionalDependencyNotAvailable:
926
+ from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
927
+ else:
928
+ from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
929
+ try:
930
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
931
+ raise OptionalDependencyNotAvailable()
932
+ except OptionalDependencyNotAvailable:
933
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
934
+ else:
935
+ from .pipelines import (
936
+ OnnxStableDiffusionImg2ImgPipeline,
937
+ OnnxStableDiffusionInpaintPipeline,
938
+ OnnxStableDiffusionInpaintPipelineLegacy,
939
+ OnnxStableDiffusionPipeline,
940
+ OnnxStableDiffusionUpscalePipeline,
941
+ StableDiffusionOnnxPipeline,
942
+ )
943
+
944
+ try:
945
+ if not (is_torch_available() and is_librosa_available()):
946
+ raise OptionalDependencyNotAvailable()
947
+ except OptionalDependencyNotAvailable:
948
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
949
+ else:
950
+ from .pipelines import AudioDiffusionPipeline, Mel
951
+
952
+ try:
953
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
954
+ raise OptionalDependencyNotAvailable()
955
+ except OptionalDependencyNotAvailable:
956
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
957
+ else:
958
+ from .pipelines import SpectrogramDiffusionPipeline
959
+
960
+ try:
961
+ if not is_flax_available():
962
+ raise OptionalDependencyNotAvailable()
963
+ except OptionalDependencyNotAvailable:
964
+ from .utils.dummy_flax_objects import * # noqa F403
965
+ else:
966
+ from .models.controlnets.controlnet_flax import FlaxControlNetModel
967
+ from .models.modeling_flax_utils import FlaxModelMixin
968
+ from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
969
+ from .models.vae_flax import FlaxAutoencoderKL
970
+ from .pipelines import FlaxDiffusionPipeline
971
+ from .schedulers import (
972
+ FlaxDDIMScheduler,
973
+ FlaxDDPMScheduler,
974
+ FlaxDPMSolverMultistepScheduler,
975
+ FlaxEulerDiscreteScheduler,
976
+ FlaxKarrasVeScheduler,
977
+ FlaxLMSDiscreteScheduler,
978
+ FlaxPNDMScheduler,
979
+ FlaxSchedulerMixin,
980
+ FlaxScoreSdeVeScheduler,
981
+ )
982
+
983
+ try:
984
+ if not (is_flax_available() and is_transformers_available()):
985
+ raise OptionalDependencyNotAvailable()
986
+ except OptionalDependencyNotAvailable:
987
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
988
+ else:
989
+ from .pipelines import (
990
+ FlaxStableDiffusionControlNetPipeline,
991
+ FlaxStableDiffusionImg2ImgPipeline,
992
+ FlaxStableDiffusionInpaintPipeline,
993
+ FlaxStableDiffusionPipeline,
994
+ FlaxStableDiffusionXLPipeline,
995
+ )
996
+
997
+ try:
998
+ if not (is_note_seq_available()):
999
+ raise OptionalDependencyNotAvailable()
1000
+ except OptionalDependencyNotAvailable:
1001
+ from .utils.dummy_note_seq_objects import * # noqa F403
1002
+ else:
1003
+ from .pipelines import MidiProcessor
1004
+
1005
+ else:
1006
+ import sys
1007
+
1008
+ sys.modules[__name__] = _LazyModule(
1009
+ __name__,
1010
+ globals()["__file__"],
1011
+ _import_structure,
1012
+ module_spec=__spec__,
1013
+ extra_objects={"__version__": __version__},
1014
+ )
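
Aside (not part of the commit): the comments near the top of this `__init__.py` describe the deferred-import scheme — `_import_structure` maps each submodule to the object names it exposes, and the final `_LazyModule` assignment makes `import diffusers` resolve those names without importing any backend. Below is a minimal, simplified sketch of that pattern with hypothetical names; it is illustrative only and omits the details of the real `_LazyModule`.

import importlib
import types


class LazyModule(types.ModuleType):
    """Exposes names from submodules but defers the imports until first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # e.g. {"schedulers": ["DDIMScheduler"], "pipelines": ["DiffusionPipeline"]}
        self._import_structure = import_structure
        self._name_to_module = {
            obj: mod for mod, objs in import_structure.items() for obj in objs
        }
        self.__all__ = list(self._name_to_module)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, i.e. on first access.
        if name in self._name_to_module:
            module = importlib.import_module(f"{self.__name__}.{self._name_to_module[name]}")
            value = getattr(module, name)
            setattr(self, name, value)  # cache so the submodule is imported only once
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
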
icedit/diffusers/callbacks.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import Any, Dict, List
2
+
3
+ from .configuration_utils import ConfigMixin, register_to_config
4
+ from .utils import CONFIG_NAME
5
+
6
+
7
+ class PipelineCallback(ConfigMixin):
8
+ """
9
+ Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
10
+ custom callbacks and ensures that all callbacks have a consistent interface.
11
+
12
+ Please implement the following:
13
+ `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
14
+ include
15
+ variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
16
+ `callback_fn`: This method defines the core functionality of your callback.
17
+ """
18
+
19
+ config_name = CONFIG_NAME
20
+
21
+ @register_to_config
22
+ def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
23
+ super().__init__()
24
+
25
+ if (cutoff_step_ratio is None and cutoff_step_index is None) or (
26
+ cutoff_step_ratio is not None and cutoff_step_index is not None
27
+ ):
28
+ raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
29
+
30
+ if cutoff_step_ratio is not None and (
31
+ not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
32
+ ):
33
+ raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
34
+
35
+ @property
36
+ def tensor_inputs(self) -> List[str]:
37
+ raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
38
+
39
+ def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
40
+ raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
41
+
42
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
43
+ return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
44
+
45
+
46
+ class MultiPipelineCallbacks:
47
+ """
48
+ This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
49
+ provides a unified interface for calling all of them.
50
+ """
51
+
52
+ def __init__(self, callbacks: List[PipelineCallback]):
53
+ self.callbacks = callbacks
54
+
55
+ @property
56
+ def tensor_inputs(self) -> List[str]:
57
+ return [input for callback in self.callbacks for input in callback.tensor_inputs]
58
+
59
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
60
+ """
61
+ Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
62
+ """
63
+ for callback in self.callbacks:
64
+ callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
65
+
66
+ return callback_kwargs
67
+
68
+
69
+ class SDCFGCutoffCallback(PipelineCallback):
70
+ """
71
+ Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
72
+ `cutoff_step_index`), this callback will disable the CFG.
73
+
74
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
75
+ """
76
+
77
+ tensor_inputs = ["prompt_embeds"]
78
+
79
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
80
+ cutoff_step_ratio = self.config.cutoff_step_ratio
81
+ cutoff_step_index = self.config.cutoff_step_index
82
+
83
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
84
+ cutoff_step = (
85
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
86
+ )
87
+
88
+ if step_index == cutoff_step:
89
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
90
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
91
+
92
+ pipeline._guidance_scale = 0.0
93
+
94
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
95
+ return callback_kwargs
96
+
97
+
98
+ class SDXLCFGCutoffCallback(PipelineCallback):
99
+ """
100
+ Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
101
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
102
+
103
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104
+ """
105
+
106
+ tensor_inputs = [
107
+ "prompt_embeds",
108
+ "add_text_embeds",
109
+ "add_time_ids",
110
+ ]
111
+
112
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
113
+ cutoff_step_ratio = self.config.cutoff_step_ratio
114
+ cutoff_step_index = self.config.cutoff_step_index
115
+
116
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
117
+ cutoff_step = (
118
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
119
+ )
120
+
121
+ if step_index == cutoff_step:
122
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
123
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
124
+
125
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
126
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
127
+
128
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
129
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
130
+
131
+ pipeline._guidance_scale = 0.0
132
+
133
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
134
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
135
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
136
+
137
+ return callback_kwargs
138
+
139
+
140
+ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
141
+ """
142
+ Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
143
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
144
+
145
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
146
+ """
147
+
148
+ tensor_inputs = [
149
+ "prompt_embeds",
150
+ "add_text_embeds",
151
+ "add_time_ids",
152
+ "image",
153
+ ]
154
+
155
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
156
+ cutoff_step_ratio = self.config.cutoff_step_ratio
157
+ cutoff_step_index = self.config.cutoff_step_index
158
+
159
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
160
+ cutoff_step = (
161
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
162
+ )
163
+
164
+ if step_index == cutoff_step:
165
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
166
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
167
+
168
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
169
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
170
+
171
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
172
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
173
+
174
+ # For Controlnet
175
+ image = callback_kwargs[self.tensor_inputs[3]]
176
+ image = image[-1:]
177
+
178
+ pipeline._guidance_scale = 0.0
179
+
180
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
181
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
182
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
183
+ callback_kwargs[self.tensor_inputs[3]] = image
184
+
185
+ return callback_kwargs
186
+
187
+
188
+ class IPAdapterScaleCutoffCallback(PipelineCallback):
189
+ """
190
+ Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by
191
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
192
+
193
+ Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
194
+ """
195
+
196
+ tensor_inputs = []
197
+
198
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
199
+ cutoff_step_ratio = self.config.cutoff_step_ratio
200
+ cutoff_step_index = self.config.cutoff_step_index
201
+
202
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
203
+ cutoff_step = (
204
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
205
+ )
206
+
207
+ if step_index == cutoff_step:
208
+ pipeline.set_ip_adapter_scale(0.0)
209
+ return callback_kwargs
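
Aside (not part of the commit): per the `PipelineCallback` docstring above, a callback may only request tensors listed in the pipeline's `_callback_tensor_inputs`, and it is handed to the pipeline at call time. A hedged usage sketch for `SDCFGCutoffCallback`, assuming the public `callback_on_step_end` / `callback_on_step_end_tensor_inputs` parameters of the Stable Diffusion pipelines and a standard SD 1.5 checkpoint (neither is defined in this diff):

import torch
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import SDCFGCutoffCallback

# Disable classifier-free guidance after 40% of the denoising steps.
callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    callback_on_step_end=callback,
    # The callback only receives the tensors it declares in `tensor_inputs`.
    callback_on_step_end_tensor_inputs=callback.tensor_inputs,
).images[0]
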
icedit/diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
icedit/diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+ from .fp16_safetensors import FP16SafetensorsCommand
20
+
21
+
22
+ def main():
23
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
24
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
25
+
26
+ # Register commands
27
+ EnvironmentCommand.register_subcommand(commands_parser)
28
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
29
+
30
+ # Let's go
31
+ args = parser.parse_args()
32
+
33
+ if not hasattr(args, "func"):
34
+ parser.print_help()
35
+ exit(1)
36
+
37
+ # Run
38
+ service = args.func(args)
39
+ service.run()
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
icedit/diffusers/commands/env.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import subprocess
17
+ from argparse import ArgumentParser
18
+
19
+ import huggingface_hub
20
+
21
+ from .. import __version__ as version
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_peft_available,
28
+ is_safetensors_available,
29
+ is_torch_available,
30
+ is_transformers_available,
31
+ is_xformers_available,
32
+ )
33
+ from . import BaseDiffusersCLICommand
34
+
35
+
36
+ def info_command_factory(_):
37
+ return EnvironmentCommand()
38
+
39
+
40
+ class EnvironmentCommand(BaseDiffusersCLICommand):
41
+ @staticmethod
42
+ def register_subcommand(parser: ArgumentParser) -> None:
43
+ download_parser = parser.add_parser("env")
44
+ download_parser.set_defaults(func=info_command_factory)
45
+
46
+ def run(self) -> dict:
47
+ hub_version = huggingface_hub.__version__
48
+
49
+ safetensors_version = "not installed"
50
+ if is_safetensors_available():
51
+ import safetensors
52
+
53
+ safetensors_version = safetensors.__version__
54
+
55
+ pt_version = "not installed"
56
+ pt_cuda_available = "NA"
57
+ if is_torch_available():
58
+ import torch
59
+
60
+ pt_version = torch.__version__
61
+ pt_cuda_available = torch.cuda.is_available()
62
+
63
+ flax_version = "not installed"
64
+ jax_version = "not installed"
65
+ jaxlib_version = "not installed"
66
+ jax_backend = "NA"
67
+ if is_flax_available():
68
+ import flax
69
+ import jax
70
+ import jaxlib
71
+
72
+ flax_version = flax.__version__
73
+ jax_version = jax.__version__
74
+ jaxlib_version = jaxlib.__version__
75
+ jax_backend = jax.lib.xla_bridge.get_backend().platform
76
+
77
+ transformers_version = "not installed"
78
+ if is_transformers_available():
79
+ import transformers
80
+
81
+ transformers_version = transformers.__version__
82
+
83
+ accelerate_version = "not installed"
84
+ if is_accelerate_available():
85
+ import accelerate
86
+
87
+ accelerate_version = accelerate.__version__
88
+
89
+ peft_version = "not installed"
90
+ if is_peft_available():
91
+ import peft
92
+
93
+ peft_version = peft.__version__
94
+
95
+ bitsandbytes_version = "not installed"
96
+ if is_bitsandbytes_available():
97
+ import bitsandbytes
98
+
99
+ bitsandbytes_version = bitsandbytes.__version__
100
+
101
+ xformers_version = "not installed"
102
+ if is_xformers_available():
103
+ import xformers
104
+
105
+ xformers_version = xformers.__version__
106
+
107
+ platform_info = platform.platform()
108
+
109
+ is_google_colab_str = "Yes" if is_google_colab() else "No"
110
+
111
+ accelerator = "NA"
112
+ if platform.system() in {"Linux", "Windows"}:
113
+ try:
114
+ sp = subprocess.Popen(
115
+ ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
116
+ stdout=subprocess.PIPE,
117
+ stderr=subprocess.PIPE,
118
+ )
119
+ out_str, _ = sp.communicate()
120
+ out_str = out_str.decode("utf-8")
121
+
122
+ if len(out_str) > 0:
123
+ accelerator = out_str.strip()
124
+ except FileNotFoundError:
125
+ pass
126
+ elif platform.system() == "Darwin": # Mac OS
127
+ try:
128
+ sp = subprocess.Popen(
129
+ ["system_profiler", "SPDisplaysDataType"],
130
+ stdout=subprocess.PIPE,
131
+ stderr=subprocess.PIPE,
132
+ )
133
+ out_str, _ = sp.communicate()
134
+ out_str = out_str.decode("utf-8")
135
+
136
+ start = out_str.find("Chipset Model:")
137
+ if start != -1:
138
+ start += len("Chipset Model:")
139
+ end = out_str.find("\n", start)
140
+ accelerator = out_str[start:end].strip()
141
+
142
+ start = out_str.find("VRAM (Total):")
143
+ if start != -1:
144
+ start += len("VRAM (Total):")
145
+ end = out_str.find("\n", start)
146
+ accelerator += " VRAM: " + out_str[start:end].strip()
147
+ except FileNotFoundError:
148
+ pass
149
+ else:
150
+ print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
151
+
152
+ info = {
153
+ "🤗 Diffusers version": version,
154
+ "Platform": platform_info,
155
+ "Running on Google Colab?": is_google_colab_str,
156
+ "Python version": platform.python_version(),
157
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
158
+ "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
159
+ "Jax version": jax_version,
160
+ "JaxLib version": jaxlib_version,
161
+ "Huggingface_hub version": hub_version,
162
+ "Transformers version": transformers_version,
163
+ "Accelerate version": accelerate_version,
164
+ "PEFT version": peft_version,
165
+ "Bitsandbytes version": bitsandbytes_version,
166
+ "Safetensors version": safetensors_version,
167
+ "xFormers version": xformers_version,
168
+ "Accelerator": accelerator,
169
+ "Using GPU in script?": "<fill in>",
170
+ "Using distributed or parallel set-up in script?": "<fill in>",
171
+ }
172
+
173
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
174
+ print(self.format_dict(info))
175
+
176
+ return info
177
+
178
+ @staticmethod
179
+ def format_dict(d: dict) -> str:
180
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
icedit/diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ import warnings
23
+ from argparse import ArgumentParser, Namespace
24
+ from importlib import import_module
25
+
26
+ import huggingface_hub
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from packaging import version
30
+
31
+ from ..utils import logging
32
+ from . import BaseDiffusersCLICommand
33
+
34
+
35
+ def conversion_command_factory(args: Namespace):
36
+ if args.use_auth_token:
37
+ warnings.warn(
38
+ "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
39
+ " handled automatically if user is logged in."
40
+ )
41
+ return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
42
+
43
+
44
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
45
+ @staticmethod
46
+ def register_subcommand(parser: ArgumentParser):
47
+ conversion_parser = parser.add_parser("fp16_safetensors")
48
+ conversion_parser.add_argument(
49
+ "--ckpt_id",
50
+ type=str,
51
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
52
+ )
53
+ conversion_parser.add_argument(
54
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
55
+ )
56
+ conversion_parser.add_argument(
57
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
58
+ )
59
+ conversion_parser.add_argument(
60
+ "--use_auth_token",
61
+ action="store_true",
62
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
63
+ )
64
+ conversion_parser.set_defaults(func=conversion_command_factory)
65
+
66
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
67
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
68
+ self.ckpt_id = ckpt_id
69
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
70
+ self.fp16 = fp16
71
+
72
+ self.use_safetensors = use_safetensors
73
+
74
+ if not self.use_safetensors and not self.fp16:
75
+ raise NotImplementedError(
76
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
77
+ )
78
+
79
+ def run(self):
80
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
81
+ raise ImportError(
82
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
83
+ " installation."
84
+ )
85
+ else:
86
+ from huggingface_hub import create_commit
87
+ from huggingface_hub._commit_api import CommitOperationAdd
88
+
89
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
90
+ with open(model_index, "r") as f:
91
+ pipeline_class_name = json.load(f)["_class_name"]
92
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
93
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
94
+
95
+ # Load the appropriate pipeline. We could have used `DiffusionPipeline`
96
+ # here, but just to avoid any rough edge cases.
97
+ pipeline = pipeline_class.from_pretrained(
98
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
99
+ )
100
+ pipeline.save_pretrained(
101
+ self.local_ckpt_dir,
102
+ safe_serialization=True if self.use_safetensors else False,
103
+ variant="fp16" if self.fp16 else None,
104
+ )
105
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
106
+
107
+ # Fetch all the paths.
108
+ if self.fp16:
109
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
110
+ elif self.use_safetensors:
111
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
112
+
113
+ # Prepare for the PR.
114
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
115
+ operations = []
116
+ for path in modified_paths:
117
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
118
+
119
+ # Open the PR.
120
+ commit_description = (
121
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
122
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
123
+ )
124
+ hub_pr_url = create_commit(
125
+ repo_id=self.ckpt_id,
126
+ operations=operations,
127
+ commit_message=commit_message,
128
+ commit_description=commit_description,
129
+ repo_type="model",
130
+ create_pr=True,
131
+ ).pr_url
132
+ self.logger.info(f"PR created here: {hub_pr_url}.")
icedit/diffusers/configuration_utils.py ADDED
@@ -0,0 +1,732 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ConfigMixin base class and utilities."""
17
+
18
+ import dataclasses
19
+ import functools
20
+ import importlib
21
+ import inspect
22
+ import json
23
+ import os
24
+ import re
25
+ from collections import OrderedDict
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Tuple, Union
28
+
29
+ import numpy as np
30
+ from huggingface_hub import create_repo, hf_hub_download
31
+ from huggingface_hub.utils import (
32
+ EntryNotFoundError,
33
+ RepositoryNotFoundError,
34
+ RevisionNotFoundError,
35
+ validate_hf_hub_args,
36
+ )
37
+ from requests import HTTPError
38
+
39
+ from . import __version__
40
+ from .utils import (
41
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
42
+ DummyObject,
43
+ deprecate,
44
+ extract_commit_hash,
45
+ http_user_agent,
46
+ logging,
47
+ )
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
53
+
54
+
55
+ class FrozenDict(OrderedDict):
56
+ def __init__(self, *args, **kwargs):
57
+ super().__init__(*args, **kwargs)
58
+
59
+ for key, value in self.items():
60
+ setattr(self, key, value)
61
+
62
+ self.__frozen = True
63
+
64
+ def __delitem__(self, *args, **kwargs):
65
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
66
+
67
+ def setdefault(self, *args, **kwargs):
68
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
69
+
70
+ def pop(self, *args, **kwargs):
71
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
72
+
73
+ def update(self, *args, **kwargs):
74
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
75
+
76
+ def __setattr__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setattr__(name, value)
80
+
81
+ def __setitem__(self, name, value):
82
+ if hasattr(self, "__frozen") and self.__frozen:
83
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
84
+ super().__setitem__(name, value)
85
+
86
+
87
+ class ConfigMixin:
88
+ r"""
89
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
90
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
91
+ saving classes that inherit from [`ConfigMixin`].
92
+
93
+ Class attributes:
94
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
95
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
96
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
97
+ overridden by subclass).
98
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
99
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
100
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
101
+ subclass).
102
+ """
103
+
104
+ config_name = None
105
+ ignore_for_config = []
106
+ has_compatibles = False
107
+
108
+ _deprecated_kwargs = []
109
+
110
+ def register_to_config(self, **kwargs):
111
+ if self.config_name is None:
112
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
113
+ # Special case for `kwargs` used in deprecation warning added to schedulers
114
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
115
+ # or solve in a more general way.
116
+ kwargs.pop("kwargs", None)
117
+
118
+ if not hasattr(self, "_internal_dict"):
119
+ internal_dict = kwargs
120
+ else:
121
+ previous_dict = dict(self._internal_dict)
122
+ internal_dict = {**self._internal_dict, **kwargs}
123
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
124
+
125
+ self._internal_dict = FrozenDict(internal_dict)
126
+
127
+ def __getattr__(self, name: str) -> Any:
128
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
129
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
130
+
131
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
132
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
133
+ """
134
+
135
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
136
+ is_attribute = name in self.__dict__
137
+
138
+ if is_in_config and not is_attribute:
139
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
140
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
141
+ return self._internal_dict[name]
142
+
143
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
144
+
145
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
146
+ """
147
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
148
+ [`~ConfigMixin.from_config`] class method.
149
+
150
+ Args:
151
+ save_directory (`str` or `os.PathLike`):
152
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
153
+ push_to_hub (`bool`, *optional*, defaults to `False`):
154
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
155
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
156
+ namespace).
157
+ kwargs (`Dict[str, Any]`, *optional*):
158
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
159
+ """
160
+ if os.path.isfile(save_directory):
161
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
162
+
163
+ os.makedirs(save_directory, exist_ok=True)
164
+
165
+ # If we save using the predefined names, we can load using `from_config`
166
+ output_config_file = os.path.join(save_directory, self.config_name)
167
+
168
+ self.to_json_file(output_config_file)
169
+ logger.info(f"Configuration saved in {output_config_file}")
170
+
171
+ if push_to_hub:
172
+ commit_message = kwargs.pop("commit_message", None)
173
+ private = kwargs.pop("private", None)
174
+ create_pr = kwargs.pop("create_pr", False)
175
+ token = kwargs.pop("token", None)
176
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
177
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
178
+
179
+ self._upload_folder(
180
+ save_directory,
181
+ repo_id,
182
+ token=token,
183
+ commit_message=commit_message,
184
+ create_pr=create_pr,
185
+ )
186
+
187
+ @classmethod
188
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
189
+ r"""
190
+ Instantiate a Python class from a config dictionary.
191
+
192
+ Parameters:
193
+ config (`Dict[str, Any]`):
194
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
195
+ files of compatible classes.
196
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
197
+ Whether kwargs that are not consumed by the Python class should be returned or not.
198
+ kwargs (remaining dictionary of keyword arguments, *optional*):
199
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
200
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
201
+ overwrite the same named arguments in `config`.
202
+
203
+ Returns:
204
+ [`ModelMixin`] or [`SchedulerMixin`]:
205
+ A model or scheduler object instantiated from a config dictionary.
206
+
207
+ Examples:
208
+
209
+ ```python
210
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
211
+
212
+ >>> # Download scheduler from huggingface.co and cache.
213
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
214
+
215
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
216
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
217
+
218
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
219
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
220
+ ```
221
+ """
222
+ # <===== TO BE REMOVED WITH DEPRECATION
223
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
224
+ if "pretrained_model_name_or_path" in kwargs:
225
+ config = kwargs.pop("pretrained_model_name_or_path")
226
+
227
+ if config is None:
228
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
229
+ # ======>
230
+
231
+ if not isinstance(config, dict):
232
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
233
+ if "Scheduler" in cls.__name__:
234
+ deprecation_message += (
235
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
236
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
237
+ " be removed in v1.0.0."
238
+ )
239
+ elif "Model" in cls.__name__:
240
+ deprecation_message += (
241
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
242
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
243
+ " instead. This functionality will be removed in v1.0.0."
244
+ )
245
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
246
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
247
+
248
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
249
+
250
+ # Allow dtype to be specified on initialization
251
+ if "dtype" in unused_kwargs:
252
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
253
+
254
+ # add possible deprecated kwargs
255
+ for deprecated_kwarg in cls._deprecated_kwargs:
256
+ if deprecated_kwarg in unused_kwargs:
257
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
258
+
259
+ # Return model and optionally state and/or unused_kwargs
260
+ model = cls(**init_dict)
261
+
262
+ # make sure to also save config parameters that might be used for compatible classes
263
+ # update _class_name
264
+ if "_class_name" in hidden_dict:
265
+ hidden_dict["_class_name"] = cls.__name__
266
+
267
+ model.register_to_config(**hidden_dict)
268
+
269
+ # add hidden kwargs of compatible classes to unused_kwargs
270
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
271
+
272
+ if return_unused_kwargs:
273
+ return (model, unused_kwargs)
274
+ else:
275
+ return model
276
+
277
+ @classmethod
278
+ def get_config_dict(cls, *args, **kwargs):
279
+ deprecation_message = (
280
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
281
+ " removed in version v1.0.0"
282
+ )
283
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
284
+ return cls.load_config(*args, **kwargs)
285
+
286
+ @classmethod
287
+ @validate_hf_hub_args
288
+ def load_config(
289
+ cls,
290
+ pretrained_model_name_or_path: Union[str, os.PathLike],
291
+ return_unused_kwargs=False,
292
+ return_commit_hash=False,
293
+ **kwargs,
294
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
295
+ r"""
296
+ Load a model or scheduler configuration.
297
+
298
+ Parameters:
299
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
300
+ Can be either:
301
+
302
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
303
+ the Hub.
304
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
305
+ [`~ConfigMixin.save_config`].
306
+
307
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
308
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
309
+ is not used.
310
+ force_download (`bool`, *optional*, defaults to `False`):
311
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
312
+ cached versions if they exist.
313
+ proxies (`Dict[str, str]`, *optional*):
314
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
315
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
316
+ output_loading_info(`bool`, *optional*, defaults to `False`):
317
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
318
+ local_files_only (`bool`, *optional*, defaults to `False`):
319
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
320
+ won't be downloaded from the Hub.
321
+ token (`str` or *bool*, *optional*):
322
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
323
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
324
+ revision (`str`, *optional*, defaults to `"main"`):
325
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
326
+ allowed by Git.
327
+ subfolder (`str`, *optional*, defaults to `""`):
328
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
329
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
330
+ Whether unused keyword arguments of the config are returned.
331
+ return_commit_hash (`bool`, *optional*, defaults to `False):
332
+ Whether the `commit_hash` of the loaded configuration are returned.
333
+
334
+ Returns:
335
+ `dict`:
336
+ A dictionary of all the parameters stored in a JSON configuration file.
337
+
338
+ """
339
+ cache_dir = kwargs.pop("cache_dir", None)
340
+ local_dir = kwargs.pop("local_dir", None)
341
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
342
+ force_download = kwargs.pop("force_download", False)
343
+ proxies = kwargs.pop("proxies", None)
344
+ token = kwargs.pop("token", None)
345
+ local_files_only = kwargs.pop("local_files_only", False)
346
+ revision = kwargs.pop("revision", None)
347
+ _ = kwargs.pop("mirror", None)
348
+ subfolder = kwargs.pop("subfolder", None)
349
+ user_agent = kwargs.pop("user_agent", {})
350
+
351
+ user_agent = {**user_agent, "file_type": "config"}
352
+ user_agent = http_user_agent(user_agent)
353
+
354
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
355
+
356
+ if cls.config_name is None:
357
+ raise ValueError(
358
+ "`self.config_name` is not defined. Note that one should not load a config from "
359
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
360
+ )
361
+
362
+ if os.path.isfile(pretrained_model_name_or_path):
363
+ config_file = pretrained_model_name_or_path
364
+ elif os.path.isdir(pretrained_model_name_or_path):
365
+ if subfolder is not None and os.path.isfile(
366
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
367
+ ):
368
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
369
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
370
+ # Load from a PyTorch checkpoint
371
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
372
+ else:
373
+ raise EnvironmentError(
374
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
375
+ )
376
+ else:
377
+ try:
378
+ # Load from URL or cache if already cached
379
+ config_file = hf_hub_download(
380
+ pretrained_model_name_or_path,
381
+ filename=cls.config_name,
382
+ cache_dir=cache_dir,
383
+ force_download=force_download,
384
+ proxies=proxies,
385
+ local_files_only=local_files_only,
386
+ token=token,
387
+ user_agent=user_agent,
388
+ subfolder=subfolder,
389
+ revision=revision,
390
+ local_dir=local_dir,
391
+ local_dir_use_symlinks=local_dir_use_symlinks,
392
+ )
393
+ except RepositoryNotFoundError:
394
+ raise EnvironmentError(
395
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
396
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
397
+ " token having permission to this repo with `token` or log in with `huggingface-cli login`."
398
+ )
399
+ except RevisionNotFoundError:
400
+ raise EnvironmentError(
401
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
402
+ " this model name. Check the model page at"
403
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
404
+ )
405
+ except EntryNotFoundError:
406
+ raise EnvironmentError(
407
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
408
+ )
409
+ except HTTPError as err:
410
+ raise EnvironmentError(
411
+ "There was a specific connection error when trying to load"
412
+ f" {pretrained_model_name_or_path}:\n{err}"
413
+ )
414
+ except ValueError:
415
+ raise EnvironmentError(
416
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
417
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
418
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
419
+ " run the library in offline mode at"
420
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
421
+ )
422
+ except EnvironmentError:
423
+ raise EnvironmentError(
424
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
425
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
426
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
427
+ f"containing a {cls.config_name} file"
428
+ )
429
+
430
+ try:
431
+ # Load config dict
432
+ config_dict = cls._dict_from_json_file(config_file)
433
+
434
+ commit_hash = extract_commit_hash(config_file)
435
+ except (json.JSONDecodeError, UnicodeDecodeError):
436
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
437
+
438
+ if not (return_unused_kwargs or return_commit_hash):
439
+ return config_dict
440
+
441
+ outputs = (config_dict,)
442
+
443
+ if return_unused_kwargs:
444
+ outputs += (kwargs,)
445
+
446
+ if return_commit_hash:
447
+ outputs += (commit_hash,)
448
+
449
+ return outputs
450
+
451
+ @staticmethod
452
+ def _get_init_keys(input_class):
453
+ return set(dict(inspect.signature(input_class.__init__).parameters).keys())
454
+
455
+ @classmethod
456
+ def extract_init_dict(cls, config_dict, **kwargs):
457
+ # Skip keys that were not present in the original config, so default __init__ values were used
458
+ used_defaults = config_dict.get("_use_default_values", [])
459
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
460
+
461
+ # 0. Copy origin config dict
462
+ original_dict = dict(config_dict.items())
463
+
464
+ # 1. Retrieve expected config attributes from __init__ signature
465
+ expected_keys = cls._get_init_keys(cls)
466
+ expected_keys.remove("self")
467
+ # remove general kwargs if present in dict
468
+ if "kwargs" in expected_keys:
469
+ expected_keys.remove("kwargs")
470
+ # remove flax internal keys
471
+ if hasattr(cls, "_flax_internal_args"):
472
+ for arg in cls._flax_internal_args:
473
+ expected_keys.remove(arg)
474
+
475
+ # 2. Remove attributes that cannot be expected from expected config attributes
476
+ # remove keys to be ignored
477
+ if len(cls.ignore_for_config) > 0:
478
+ expected_keys = expected_keys - set(cls.ignore_for_config)
479
+
480
+ # load diffusers library to import compatible and original scheduler
481
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
482
+
483
+ if cls.has_compatibles:
484
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
485
+ else:
486
+ compatible_classes = []
487
+
488
+ expected_keys_comp_cls = set()
489
+ for c in compatible_classes:
490
+ expected_keys_c = cls._get_init_keys(c)
491
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
492
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
493
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
494
+
495
+ # remove attributes from orig class that cannot be expected
496
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
497
+ if (
498
+ isinstance(orig_cls_name, str)
499
+ and orig_cls_name != cls.__name__
500
+ and hasattr(diffusers_library, orig_cls_name)
501
+ ):
502
+ orig_cls = getattr(diffusers_library, orig_cls_name)
503
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
504
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
505
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
506
+ raise ValueError(
507
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
508
+ )
509
+
510
+ # remove private attributes
511
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
512
+
513
+ # remove quantization_config
514
+ config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
515
+
516
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
517
+ init_dict = {}
518
+ for key in expected_keys:
519
+ # if config param is passed to kwarg and is present in config dict
520
+ # it should overwrite existing config dict key
521
+ if key in kwargs and key in config_dict:
522
+ config_dict[key] = kwargs.pop(key)
523
+
524
+ if key in kwargs:
525
+ # overwrite key
526
+ init_dict[key] = kwargs.pop(key)
527
+ elif key in config_dict:
528
+ # use value from config dict
529
+ init_dict[key] = config_dict.pop(key)
530
+
531
+ # 4. Give nice warning if unexpected values have been passed
532
+ if len(config_dict) > 0:
533
+ logger.warning(
534
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
535
+ "but are not expected and will be ignored. Please verify your "
536
+ f"{cls.config_name} configuration file."
537
+ )
538
+
539
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
540
+ passed_keys = set(init_dict.keys())
541
+ if len(expected_keys - passed_keys) > 0:
542
+ logger.info(
543
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
544
+ )
545
+
546
+ # 6. Define unused keyword arguments
547
+ unused_kwargs = {**config_dict, **kwargs}
548
+
549
+ # 7. Define "hidden" config parameters that were saved for compatible classes
550
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
551
+
552
+ return init_dict, unused_kwargs, hidden_config_dict
553
+
554
+ @classmethod
555
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
556
+ with open(json_file, "r", encoding="utf-8") as reader:
557
+ text = reader.read()
558
+ return json.loads(text)
559
+
560
+ def __repr__(self):
561
+ return f"{self.__class__.__name__} {self.to_json_string()}"
562
+
563
+ @property
564
+ def config(self) -> Dict[str, Any]:
565
+ """
566
+ Returns the config of the class as a frozen dictionary
567
+
568
+ Returns:
569
+ `Dict[str, Any]`: Config of the class.
570
+ """
571
+ return self._internal_dict
572
+
573
+ def to_json_string(self) -> str:
574
+ """
575
+ Serializes the configuration instance to a JSON string.
576
+
577
+ Returns:
578
+ `str`:
579
+ String containing all the attributes that make up the configuration instance in JSON format.
580
+ """
581
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
582
+ config_dict["_class_name"] = self.__class__.__name__
583
+ config_dict["_diffusers_version"] = __version__
584
+
585
+ def to_json_saveable(value):
586
+ if isinstance(value, np.ndarray):
587
+ value = value.tolist()
588
+ elif isinstance(value, Path):
589
+ value = value.as_posix()
590
+ return value
591
+
592
+ if "quantization_config" in config_dict:
593
+ config_dict["quantization_config"] = (
594
+ config_dict.quantization_config.to_dict()
595
+ if not isinstance(config_dict.quantization_config, dict)
596
+ else config_dict.quantization_config
597
+ )
598
+
599
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
600
+ # Don't save "_ignore_files" or "_use_default_values"
601
+ config_dict.pop("_ignore_files", None)
602
+ config_dict.pop("_use_default_values", None)
603
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
604
+ _ = config_dict.pop("_pre_quantization_dtype", None)
605
+
606
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
607
+
608
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
609
+ """
610
+ Save the configuration instance's parameters to a JSON file.
611
+
612
+ Args:
613
+ json_file_path (`str` or `os.PathLike`):
614
+ Path to the JSON file to save a configuration instance's parameters.
615
+ """
616
+ with open(json_file_path, "w", encoding="utf-8") as writer:
617
+ writer.write(self.to_json_string())
618
+
619
+
620
+ def register_to_config(init):
621
+ r"""
622
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
623
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
624
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
625
+
626
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
627
+ """
628
+
629
+ @functools.wraps(init)
630
+ def inner_init(self, *args, **kwargs):
631
+ # Ignore private kwargs in the init.
632
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
633
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
634
+ if not isinstance(self, ConfigMixin):
635
+ raise RuntimeError(
636
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
637
+ "not inherit from `ConfigMixin`."
638
+ )
639
+
640
+ ignore = getattr(self, "ignore_for_config", [])
641
+ # Get positional arguments aligned with kwargs
642
+ new_kwargs = {}
643
+ signature = inspect.signature(init)
644
+ parameters = {
645
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
646
+ }
647
+ for arg, name in zip(args, parameters.keys()):
648
+ new_kwargs[name] = arg
649
+
650
+ # Then add all kwargs
651
+ new_kwargs.update(
652
+ {
653
+ k: init_kwargs.get(k, default)
654
+ for k, default in parameters.items()
655
+ if k not in ignore and k not in new_kwargs
656
+ }
657
+ )
658
+
659
+ # Take note of the parameters that were not present in the loaded config
660
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
661
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
662
+
663
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
664
+ getattr(self, "register_to_config")(**new_kwargs)
665
+ init(self, *args, **init_kwargs)
666
+
667
+ return inner_init
668
+
669
+
670
+ def flax_register_to_config(cls):
671
+ original_init = cls.__init__
672
+
673
+ @functools.wraps(original_init)
674
+ def init(self, *args, **kwargs):
675
+ if not isinstance(self, ConfigMixin):
676
+ raise RuntimeError(
677
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
678
+ "not inherit from `ConfigMixin`."
679
+ )
680
+
681
+ # Ignore private kwargs in the init. Retrieve all passed attributes
682
+ init_kwargs = dict(kwargs.items())
683
+
684
+ # Retrieve default values
685
+ fields = dataclasses.fields(self)
686
+ default_kwargs = {}
687
+ for field in fields:
688
+ # ignore flax specific attributes
689
+ if field.name in self._flax_internal_args:
690
+ continue
691
+ if type(field.default) == dataclasses._MISSING_TYPE:
692
+ default_kwargs[field.name] = None
693
+ else:
694
+ default_kwargs[field.name] = getattr(self, field.name)
695
+
696
+ # Make sure init_kwargs override default kwargs
697
+ new_kwargs = {**default_kwargs, **init_kwargs}
698
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
699
+ if "dtype" in new_kwargs:
700
+ new_kwargs.pop("dtype")
701
+
702
+ # Get positional arguments aligned with kwargs
703
+ for i, arg in enumerate(args):
704
+ name = fields[i].name
705
+ new_kwargs[name] = arg
706
+
707
+ # Take note of the parameters that were not present in the loaded config
708
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
709
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
710
+
711
+ getattr(self, "register_to_config")(**new_kwargs)
712
+ original_init(self, *args, **kwargs)
713
+
714
+ cls.__init__ = init
715
+ return cls
716
+
717
+
718
+ class LegacyConfigMixin(ConfigMixin):
719
+ r"""
720
+ A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
721
+ pipeline-specific classes (like `DiTTransformer2DModel`).
722
+ """
723
+
724
+ @classmethod
725
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
726
+ # To prevent dependency import problem.
727
+ from .models.model_loading_utils import _fetch_remapped_cls_from_config
728
+
729
+ # resolve remapping
730
+ remapped_class = _fetch_remapped_cls_from_config(config, cls)
731
+
732
+ return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
icedit/diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
+ # define which module versions we always want to check at run time
20
+ # (usually the ones defined in `install_requires` in setup.py)
21
+ #
22
+ # order specific notes:
23
+ # - tqdm must be checked before tokenizers
24
+
25
+ pkgs_to_check_at_runtime = "python requests filelock numpy".split()
26
+ for pkg in pkgs_to_check_at_runtime:
27
+ if pkg in deps:
28
+ require_version_core(deps[pkg])
29
+ else:
30
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
31
+
32
+
33
+ def dep_version_check(pkg, hint=None):
34
+ require_version(deps[pkg], hint)
icedit/diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.31.0",
7
+ "compel": "compel==0.1.8",
8
+ "datasets": "datasets",
9
+ "filelock": "filelock",
10
+ "flax": "flax>=0.4.1",
11
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
12
+ "huggingface-hub": "huggingface-hub>=0.23.2",
13
+ "requests-mock": "requests-mock==1.10.0",
14
+ "importlib_metadata": "importlib_metadata",
15
+ "invisible-watermark": "invisible-watermark>=0.2.0",
16
+ "isort": "isort>=5.5.4",
17
+ "jax": "jax>=0.4.1",
18
+ "jaxlib": "jaxlib>=0.4.1",
19
+ "Jinja2": "Jinja2",
20
+ "k-diffusion": "k-diffusion>=0.0.12",
21
+ "torchsde": "torchsde",
22
+ "note_seq": "note_seq",
23
+ "librosa": "librosa",
24
+ "numpy": "numpy",
25
+ "parameterized": "parameterized",
26
+ "peft": "peft>=0.6.0",
27
+ "protobuf": "protobuf>=3.20.3,<4",
28
+ "pytest": "pytest",
29
+ "pytest-timeout": "pytest-timeout",
30
+ "pytest-xdist": "pytest-xdist",
31
+ "python": "python>=3.8.0",
32
+ "ruff": "ruff==0.1.5",
33
+ "safetensors": "safetensors>=0.3.1",
34
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35
+ "GitPython": "GitPython<3.1.19",
36
+ "scipy": "scipy",
37
+ "onnx": "onnx",
38
+ "regex": "regex!=2019.12.17",
39
+ "requests": "requests",
40
+ "tensorboard": "tensorboard",
41
+ "torch": "torch>=1.4",
42
+ "torchvision": "torchvision",
43
+ "transformers": "transformers>=4.41.2",
44
+ "urllib3": "urllib3<=2.0.0",
45
+ "black": "black",
46
+ }
icedit/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .rl import ValueGuidedRLPipeline
icedit/diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
icedit/diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unets.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
+ class ValueGuidedRLPipeline(DiffusionPipeline):
26
+ r"""
27
+ Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.
28
+
29
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
30
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
31
+
32
+ Parameters:
33
+ value_function ([`UNet1DModel`]):
34
+ A specialized UNet for fine-tuning trajectories based on reward.
35
+ unet ([`UNet1DModel`]):
36
+ UNet architecture to denoise the encoded trajectories.
37
+ scheduler ([`SchedulerMixin`]):
38
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
39
+ application is [`DDPMScheduler`].
40
+ env ():
41
+ An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ value_function: UNet1DModel,
47
+ unet: UNet1DModel,
48
+ scheduler: DDPMScheduler,
49
+ env,
50
+ ):
51
+ super().__init__()
52
+
53
+ self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)
54
+
55
+ self.data = env.get_dataset()
56
+ self.means = {}
57
+ for key in self.data.keys():
58
+ try:
59
+ self.means[key] = self.data[key].mean()
60
+ except: # noqa: E722
61
+ pass
62
+ self.stds = {}
63
+ for key in self.data.keys():
64
+ try:
65
+ self.stds[key] = self.data[key].std()
66
+ except: # noqa: E722
67
+ pass
68
+ self.state_dim = env.observation_space.shape[0]
69
+ self.action_dim = env.action_space.shape[0]
70
+
71
+ def normalize(self, x_in, key):
72
+ return (x_in - self.means[key]) / self.stds[key]
73
+
74
+ def de_normalize(self, x_in, key):
75
+ return x_in * self.stds[key] + self.means[key]
76
+
77
+ def to_torch(self, x_in):
78
+ if isinstance(x_in, dict):
79
+ return {k: self.to_torch(v) for k, v in x_in.items()}
80
+ elif torch.is_tensor(x_in):
81
+ return x_in.to(self.unet.device)
82
+ return torch.tensor(x_in, device=self.unet.device)
83
+
84
+ def reset_x0(self, x_in, cond, act_dim):
85
+ for key, val in cond.items():
86
+ x_in[:, key, act_dim:] = val.clone()
87
+ return x_in
88
+
89
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
90
+ batch_size = x.shape[0]
91
+ y = None
92
+ for i in tqdm.tqdm(self.scheduler.timesteps):
93
+ # create batch of timesteps to pass into model
94
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
95
+ for _ in range(n_guide_steps):
96
+ with torch.enable_grad():
97
+ x.requires_grad_()
98
+
99
+ # permute to match dimension for pre-trained models
100
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
101
+ grad = torch.autograd.grad([y.sum()], [x])[0]
102
+
103
+ posterior_variance = self.scheduler._get_variance(i)
104
+ model_std = torch.exp(0.5 * posterior_variance)
105
+ grad = model_std * grad
106
+
107
+ grad[timesteps < 2] = 0
108
+ x = x.detach()
109
+ x = x + scale * grad
110
+ x = self.reset_x0(x, conditions, self.action_dim)
111
+
112
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
113
+
114
+ # TODO: verify deprecation of this kwarg
115
+ x = self.scheduler.step(prev_x, i, x)["prev_sample"]
116
+
117
+ # apply conditions to the trajectory (set the initial state)
118
+ x = self.reset_x0(x, conditions, self.action_dim)
119
+ x = self.to_torch(x)
120
+ return x, y
121
+
122
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
123
+ # normalize the observations and create batch dimension
124
+ obs = self.normalize(obs, "observations")
125
+ obs = obs[None].repeat(batch_size, axis=0)
126
+
127
+ conditions = {0: self.to_torch(obs)}
128
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
129
+
130
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
131
+ x1 = randn_tensor(shape, device=self.unet.device)
132
+ x = self.reset_x0(x1, conditions, self.action_dim)
133
+ x = self.to_torch(x)
134
+
135
+ # run the diffusion process
136
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
137
+
138
+ # sort output trajectories by value
139
+ sorted_idx = y.argsort(0, descending=True).squeeze()
140
+ sorted_values = x[sorted_idx]
141
+ actions = sorted_values[:, :, : self.action_dim]
142
+ actions = actions.detach().cpu().numpy()
143
+ denorm_actions = self.de_normalize(actions, key="actions")
144
+
145
+ # select the action with the highest value
146
+ if y is not None:
147
+ selected_index = 0
148
+ else:
149
+ # if we didn't run value guiding, select a random action
150
+ selected_index = np.random.randint(0, batch_size)
151
+
152
+ denorm_actions = denorm_actions[selected_index, 0]
153
+ return denorm_actions
icedit/diffusers/image_processor.py ADDED
@@ -0,0 +1,1314 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from PIL import Image, ImageFilter, ImageOps
24
+
25
+ from .configuration_utils import ConfigMixin, register_to_config
26
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
27
+
28
+
29
+ PipelineImageInput = Union[
30
+ PIL.Image.Image,
31
+ np.ndarray,
32
+ torch.Tensor,
33
+ List[PIL.Image.Image],
34
+ List[np.ndarray],
35
+ List[torch.Tensor],
36
+ ]
37
+
38
+ PipelineDepthInput = PipelineImageInput
39
+
40
+
41
+ def is_valid_image(image) -> bool:
42
+ r"""
43
+ Checks if the input is a valid image.
44
+
45
+ A valid image can be:
46
+ - A `PIL.Image.Image`.
47
+ - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
48
+
49
+ Args:
50
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
51
+ The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
52
+
53
+ Returns:
54
+ `bool`:
55
+ `True` if the input is a valid image, `False` otherwise.
56
+ """
57
+ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
58
+
59
+
60
+ def is_valid_image_imagelist(images):
61
+ r"""
62
+ Checks if the input is a valid image or list of images.
63
+
64
+ The input can be one of the following formats:
65
+ - A 4D tensor or numpy array (batch of images).
66
+ - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
67
+ `torch.Tensor`.
68
+ - A list of valid images.
69
+
70
+ Args:
71
+ images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
72
+ The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
73
+ images.
74
+
75
+ Returns:
76
+ `bool`:
77
+ `True` if the input is valid, `False` otherwise.
78
+ """
79
+ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
80
+ return True
81
+ elif is_valid_image(images):
82
+ return True
83
+ elif isinstance(images, list):
84
+ return all(is_valid_image(image) for image in images)
85
+ return False
86
+
87
+
88
+ class VaeImageProcessor(ConfigMixin):
89
+ """
90
+ Image processor for VAE.
91
+
92
+ Args:
93
+ do_resize (`bool`, *optional*, defaults to `True`):
94
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
95
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
96
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
97
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
98
+ resample (`str`, *optional*, defaults to `lanczos`):
99
+ Resampling filter to use when resizing the image.
100
+ do_normalize (`bool`, *optional*, defaults to `True`):
101
+ Whether to normalize the image to [-1,1].
102
+ do_binarize (`bool`, *optional*, defaults to `False`):
103
+ Whether to binarize the image to 0/1.
104
+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
105
+ Whether to convert the images to RGB format.
106
+ do_convert_grayscale (`bool`, *optional*, defaults to `False`):
107
+ Whether to convert the images to grayscale format.
108
+ """
109
+
110
+ config_name = CONFIG_NAME
111
+
112
+ @register_to_config
113
+ def __init__(
114
+ self,
115
+ do_resize: bool = True,
116
+ vae_scale_factor: int = 8,
117
+ vae_latent_channels: int = 4,
118
+ resample: str = "lanczos",
119
+ do_normalize: bool = True,
120
+ do_binarize: bool = False,
121
+ do_convert_rgb: bool = False,
122
+ do_convert_grayscale: bool = False,
123
+ ):
124
+ super().__init__()
125
+ if do_convert_rgb and do_convert_grayscale:
126
+ raise ValueError(
127
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
128
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
129
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
130
+ )
131
+
132
+ @staticmethod
133
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
134
+ r"""
135
+ Convert a numpy image or a batch of images to a PIL image.
136
+
137
+ Args:
138
+ images (`np.ndarray`):
139
+ The image array to convert to PIL format.
140
+
141
+ Returns:
142
+ `List[PIL.Image.Image]`:
143
+ A list of PIL images.
144
+ """
145
+ if images.ndim == 3:
146
+ images = images[None, ...]
147
+ images = (images * 255).round().astype("uint8")
148
+ if images.shape[-1] == 1:
149
+ # special case for grayscale (single channel) images
150
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
151
+ else:
152
+ pil_images = [Image.fromarray(image) for image in images]
153
+
154
+ return pil_images
155
+
156
+ @staticmethod
157
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
158
+ r"""
159
+ Convert a PIL image or a list of PIL images to NumPy arrays.
160
+
161
+ Args:
162
+ images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
163
+ The PIL image or list of images to convert to NumPy format.
164
+
165
+ Returns:
166
+ `np.ndarray`:
167
+ A NumPy array representation of the images.
168
+ """
169
+ if not isinstance(images, list):
170
+ images = [images]
171
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
172
+ images = np.stack(images, axis=0)
173
+
174
+ return images
175
+
176
+ @staticmethod
177
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
178
+ r"""
179
+ Convert a NumPy image to a PyTorch tensor.
180
+
181
+ Args:
182
+ images (`np.ndarray`):
183
+ The NumPy image array to convert to PyTorch format.
184
+
185
+ Returns:
186
+ `torch.Tensor`:
187
+ A PyTorch tensor representation of the images.
188
+ """
189
+ if images.ndim == 3:
190
+ images = images[..., None]
191
+
192
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
193
+ return images
194
+
195
+ @staticmethod
196
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
197
+ r"""
198
+ Convert a PyTorch tensor to a NumPy image.
199
+
200
+ Args:
201
+ images (`torch.Tensor`):
202
+ The PyTorch tensor to convert to NumPy format.
203
+
204
+ Returns:
205
+ `np.ndarray`:
206
+ A NumPy array representation of the images.
207
+ """
208
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
209
+ return images
210
+
211
+ @staticmethod
212
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
213
+ r"""
214
+ Normalize an image array to [-1,1].
215
+
216
+ Args:
217
+ images (`np.ndarray` or `torch.Tensor`):
218
+ The image array to normalize.
219
+
220
+ Returns:
221
+ `np.ndarray` or `torch.Tensor`:
222
+ The normalized image array.
223
+ """
224
+ return 2.0 * images - 1.0
225
+
226
+ @staticmethod
227
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
228
+ r"""
229
+ Denormalize an image array to [0,1].
230
+
231
+ Args:
232
+ images (`np.ndarray` or `torch.Tensor`):
233
+ The image array to denormalize.
234
+
235
+ Returns:
236
+ `np.ndarray` or `torch.Tensor`:
237
+ The denormalized image array.
238
+ """
239
+ return (images * 0.5 + 0.5).clamp(0, 1)
240
+
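As a quick sanity check on the [-1, 1] convention used by `normalize`/`denormalize`, a round-trip sketch (illustrative, not part of the commit; assumes diffusers and torch are installed):

import torch
from diffusers.image_processor import VaeImageProcessor

x = torch.rand(1, 3, 8, 8)              # pixel values in [0, 1]
y = VaeImageProcessor.normalize(x)      # mapped to [-1, 1]
z = VaeImageProcessor.denormalize(y)    # mapped back to [0, 1] (and clamped)
assert torch.allclose(x, z, atol=1e-6)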
241
+ @staticmethod
242
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
243
+ r"""
244
+ Converts a PIL image to RGB format.
245
+
246
+ Args:
247
+ image (`PIL.Image.Image`):
248
+ The PIL image to convert to RGB.
249
+
250
+ Returns:
251
+ `PIL.Image.Image`:
252
+ The RGB-converted PIL image.
253
+ """
254
+ image = image.convert("RGB")
255
+
256
+ return image
257
+
258
+ @staticmethod
259
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
260
+ r"""
261
+ Converts a given PIL image to grayscale.
262
+
263
+ Args:
264
+ image (`PIL.Image.Image`):
265
+ The input image to convert.
266
+
267
+ Returns:
268
+ `PIL.Image.Image`:
269
+ The image converted to grayscale.
270
+ """
271
+ image = image.convert("L")
272
+
273
+ return image
274
+
275
+ @staticmethod
276
+ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
277
+ r"""
278
+ Applies Gaussian blur to an image.
279
+
280
+ Args:
281
+ image (`PIL.Image.Image`):
282
+ The PIL image to blur.
283
+
284
+ Returns:
285
+ `PIL.Image.Image`:
286
+ The blurred PIL image.
287
+ """
288
+ image = image.filter(ImageFilter.GaussianBlur(blur_factor))
289
+
290
+ return image
291
+
292
+ @staticmethod
293
+ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
294
+ r"""
295
+ Finds a rectangular region that contains all masked areas in an image, and expands the region to match the aspect
296
+ ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the dimensions for
297
+ processing are 512x512, the region will be expanded to 128x128.
298
+
299
+ Args:
300
+ mask_image (PIL.Image.Image): Mask image.
301
+ width (int): Width of the image to be processed.
302
+ height (int): Height of the image to be processed.
303
+ pad (int, optional): Padding to be added to the crop region. Defaults to 0.
304
+
305
+ Returns:
306
+ tuple: (x1, y1, x2, y2) representing a rectangular region that contains all masked areas in an image and
307
+ matches the original aspect ratio.
308
+ """
309
+
310
+ mask_image = mask_image.convert("L")
311
+ mask = np.array(mask_image)
312
+
313
+ # 1. find a rectangular region that contains all masked areas in an image
314
+ h, w = mask.shape
315
+ crop_left = 0
316
+ for i in range(w):
317
+ if not (mask[:, i] == 0).all():
318
+ break
319
+ crop_left += 1
320
+
321
+ crop_right = 0
322
+ for i in reversed(range(w)):
323
+ if not (mask[:, i] == 0).all():
324
+ break
325
+ crop_right += 1
326
+
327
+ crop_top = 0
328
+ for i in range(h):
329
+ if not (mask[i] == 0).all():
330
+ break
331
+ crop_top += 1
332
+
333
+ crop_bottom = 0
334
+ for i in reversed(range(h)):
335
+ if not (mask[i] == 0).all():
336
+ break
337
+ crop_bottom += 1
338
+
339
+ # 2. add padding to the crop region
340
+ x1, y1, x2, y2 = (
341
+ int(max(crop_left - pad, 0)),
342
+ int(max(crop_top - pad, 0)),
343
+ int(min(w - crop_right + pad, w)),
344
+ int(min(h - crop_bottom + pad, h)),
345
+ )
346
+
347
+ # 3. expands crop region to match the aspect ratio of the image to be processed
348
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
349
+ ratio_processing = width / height
350
+
351
+ if ratio_crop_region > ratio_processing:
352
+ desired_height = (x2 - x1) / ratio_processing
353
+ desired_height_diff = int(desired_height - (y2 - y1))
354
+ y1 -= desired_height_diff // 2
355
+ y2 += desired_height_diff - desired_height_diff // 2
356
+ if y2 >= mask_image.height:
357
+ diff = y2 - mask_image.height
358
+ y2 -= diff
359
+ y1 -= diff
360
+ if y1 < 0:
361
+ y2 -= y1
362
+ y1 -= y1
363
+ if y2 >= mask_image.height:
364
+ y2 = mask_image.height
365
+ else:
366
+ desired_width = (y2 - y1) * ratio_processing
367
+ desired_width_diff = int(desired_width - (x2 - x1))
368
+ x1 -= desired_width_diff // 2
369
+ x2 += desired_width_diff - desired_width_diff // 2
370
+ if x2 >= mask_image.width:
371
+ diff = x2 - mask_image.width
372
+ x2 -= diff
373
+ x1 -= diff
374
+ if x1 < 0:
375
+ x2 -= x1
376
+ x1 -= x1
377
+ if x2 >= mask_image.width:
378
+ x2 = mask_image.width
379
+
380
+ return x1, y1, x2, y2
381
+
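A worked example of `get_crop_region` matching the docstring's 128x32 to 128x128 case (illustrative sketch, not part of the commit):

import numpy as np
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

mask = np.zeros((512, 512), dtype=np.uint8)
mask[240:272, 192:320] = 255            # masked strip: 128 wide, 32 tall
x1, y1, x2, y2 = VaeImageProcessor.get_crop_region(Image.fromarray(mask), 512, 512)
print(x2 - x1, y2 - y1)                 # 128 128: the strip is grown to match the 1:1 target aspect ratio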
382
+ def _resize_and_fill(
383
+ self,
384
+ image: PIL.Image.Image,
385
+ width: int,
386
+ height: int,
387
+ ) -> PIL.Image.Image:
388
+ r"""
389
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
390
+ the image within the dimensions, filling the empty space with data from the image.
391
+
392
+ Args:
393
+ image (`PIL.Image.Image`):
394
+ The image to resize and fill.
395
+ width (`int`):
396
+ The width to resize the image to.
397
+ height (`int`):
398
+ The height to resize the image to.
399
+
400
+ Returns:
401
+ `PIL.Image.Image`:
402
+ The resized and filled image.
403
+ """
404
+
405
+ ratio = width / height
406
+ src_ratio = image.width / image.height
407
+
408
+ src_w = width if ratio < src_ratio else image.width * height // image.height
409
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
410
+
411
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
412
+ res = Image.new("RGB", (width, height))
413
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
414
+
415
+ if ratio < src_ratio:
416
+ fill_height = height // 2 - src_h // 2
417
+ if fill_height > 0:
418
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
419
+ res.paste(
420
+ resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
421
+ box=(0, fill_height + src_h),
422
+ )
423
+ elif ratio > src_ratio:
424
+ fill_width = width // 2 - src_w // 2
425
+ if fill_width > 0:
426
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
427
+ res.paste(
428
+ resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
429
+ box=(fill_width + src_w, 0),
430
+ )
431
+
432
+ return res
433
+
434
+ def _resize_and_crop(
435
+ self,
436
+ image: PIL.Image.Image,
437
+ width: int,
438
+ height: int,
439
+ ) -> PIL.Image.Image:
440
+ r"""
441
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
442
+ the image within the dimensions, cropping the excess.
443
+
444
+ Args:
445
+ image (`PIL.Image.Image`):
446
+ The image to resize and crop.
447
+ width (`int`):
448
+ The width to resize the image to.
449
+ height (`int`):
450
+ The height to resize the image to.
451
+
452
+ Returns:
453
+ `PIL.Image.Image`:
454
+ The resized and cropped image.
455
+ """
456
+ ratio = width / height
457
+ src_ratio = image.width / image.height
458
+
459
+ src_w = width if ratio > src_ratio else image.width * height // image.height
460
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
461
+
462
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
463
+ res = Image.new("RGB", (width, height))
464
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
465
+ return res
466
+
467
+ def resize(
468
+ self,
469
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
470
+ height: int,
471
+ width: int,
472
+ resize_mode: str = "default", # "default", "fill", "crop"
473
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
474
+ """
475
+ Resize image.
476
+
477
+ Args:
478
+ image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
479
+ The image input, can be a PIL image, numpy array or pytorch tensor.
480
+ height (`int`):
481
+ The height to resize to.
482
+ width (`int`):
483
+ The width to resize to.
484
+ resize_mode (`str`, *optional*, defaults to `default`):
485
+ The resize mode to use, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to fit
486
+ within the specified width and height, and it may not maintain the original aspect ratio. If `fill`,
487
+ will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
488
+ then center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize
489
+ the image to fit within the specified width and height, maintaining the aspect ratio, and then center
490
+ the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
491
+ supported for PIL image input.
492
+
493
+ Returns:
494
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
495
+ The resized image.
496
+ """
497
+ if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
498
+ raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
499
+ if isinstance(image, PIL.Image.Image):
500
+ if resize_mode == "default":
501
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
502
+ elif resize_mode == "fill":
503
+ image = self._resize_and_fill(image, width, height)
504
+ elif resize_mode == "crop":
505
+ image = self._resize_and_crop(image, width, height)
506
+ else:
507
+ raise ValueError(f"resize_mode {resize_mode} is not supported")
508
+
509
+ elif isinstance(image, torch.Tensor):
510
+ image = torch.nn.functional.interpolate(
511
+ image,
512
+ size=(height, width),
513
+ )
514
+ elif isinstance(image, np.ndarray):
515
+ image = self.numpy_to_pt(image)
516
+ image = torch.nn.functional.interpolate(
517
+ image,
518
+ size=(height, width),
519
+ )
520
+ image = self.pt_to_numpy(image)
521
+ return image
522
+
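A short sketch of the three resize modes on a PIL input (illustrative; as the docstring notes, `fill` and `crop` are only supported for PIL images):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor()
img = Image.new("RGB", (640, 360))      # 16:9 source

stretched = proc.resize(img, height=512, width=512)                   # "default": aspect ratio not preserved
filled = proc.resize(img, height=512, width=512, resize_mode="fill")  # letterbox, borders filled from edge content
cropped = proc.resize(img, height=512, width=512, resize_mode="crop") # aspect-preserving resize, then center crop
print(stretched.size, filled.size, cropped.size)                      # all (512, 512)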
523
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
524
+ """
525
+ Create a mask.
526
+
527
+ Args:
528
+ image (`PIL.Image.Image`):
529
+ The image input, should be a PIL image.
530
+
531
+ Returns:
532
+ `PIL.Image.Image`:
533
+ The binarized image. Values below 0.5 are set to 0; values of 0.5 and above are set to 1.
534
+ """
535
+ image[image < 0.5] = 0
536
+ image[image >= 0.5] = 1
537
+
538
+ return image
539
+
540
+ def _denormalize_conditionally(
541
+ self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
542
+ ) -> torch.Tensor:
543
+ r"""
544
+ Denormalize a batch of images based on a condition list.
545
+
546
+ Args:
547
+ images (`torch.Tensor`):
548
+ The input image tensor.
549
+ do_denormalize (`Optional[List[bool]]`, *optional*, defaults to `None`):
550
+ A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
551
+ value of `do_normalize` in the `VaeImageProcessor` config.
552
+ """
553
+ if do_denormalize is None:
554
+ return self.denormalize(images) if self.config.do_normalize else images
555
+
556
+ return torch.stack(
557
+ [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
558
+ )
559
+
560
+ def get_default_height_width(
561
+ self,
562
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
563
+ height: Optional[int] = None,
564
+ width: Optional[int] = None,
565
+ ) -> Tuple[int, int]:
566
+ r"""
567
+ Returns the height and width of the image, rounded down to the nearest integer multiple of `vae_scale_factor`.
568
+
569
+ Args:
570
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
571
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
572
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
573
+ tensor, it should have shape `[batch, channels, height, width]`.
574
+ height (`Optional[int]`, *optional*, defaults to `None`):
575
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
576
+ width (`Optional[int]`, *optional*, defaults to `None`):
577
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
578
+
579
+ Returns:
580
+ `Tuple[int, int]`:
581
+ A tuple containing the height and width, both rounded down to the nearest integer multiple of
582
+ `vae_scale_factor`.
583
+ """
584
+
585
+ if height is None:
586
+ if isinstance(image, PIL.Image.Image):
587
+ height = image.height
588
+ elif isinstance(image, torch.Tensor):
589
+ height = image.shape[2]
590
+ else:
591
+ height = image.shape[1]
592
+
593
+ if width is None:
594
+ if isinstance(image, PIL.Image.Image):
595
+ width = image.width
596
+ elif isinstance(image, torch.Tensor):
597
+ width = image.shape[3]
598
+ else:
599
+ width = image.shape[2]
600
+
601
+ width, height = (
602
+ x - x % self.config.vae_scale_factor for x in (width, height)
603
+ ) # resize to integer multiple of vae_scale_factor
604
+
605
+ return height, width
606
+
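A worked example of the rounding performed by `get_default_height_width` (illustrative sketch):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor(vae_scale_factor=8)
img = Image.new("RGB", (517, 389))            # width=517, height=389
print(proc.get_default_height_width(img))     # (384, 512): both rounded down to a multiple of 8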
607
+ def preprocess(
608
+ self,
609
+ image: PipelineImageInput,
610
+ height: Optional[int] = None,
611
+ width: Optional[int] = None,
612
+ resize_mode: str = "default", # "default", "fill", "crop"
613
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
614
+ ) -> torch.Tensor:
615
+ """
616
+ Preprocess the image input.
617
+
618
+ Args:
619
+ image (`PipelineImageInput`):
620
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
621
+ supported formats.
622
+ height (`int`, *optional*):
623
+ The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default
624
+ height.
625
+ width (`int`, *optional*):
626
+ The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
627
+ resize_mode (`str`, *optional*, defaults to `default`):
628
+ The resize mode, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to fit within
629
+ the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
630
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
631
+ center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize the
632
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
633
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
634
+ supported for PIL image input.
635
+ crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
636
+ The crop coordinates applied to each image in the batch. If `None`, will not crop the image.
637
+
638
+ Returns:
639
+ `torch.Tensor`:
640
+ The preprocessed image.
641
+ """
642
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
643
+
644
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
645
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
646
+ if isinstance(image, torch.Tensor):
647
+ # if image is a pytorch tensor could have 2 possible shapes:
648
+ # 1. batch x height x width: we should insert the channel dimension at position 1
649
+ # 2. channel x height x width: we should insert batch dimension at position 0,
650
+ # however, since both the channel and batch dimensions have size 1, it is equivalent to insert at position 1;
651
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
652
+ image = image.unsqueeze(1)
653
+ else:
654
+ # if it is a numpy array, it could have 2 possible shapes:
655
+ # 1. batch x height x width: insert channel dimension on last position
656
+ # 2. height x width x channel: insert batch dimension on first position
657
+ if image.shape[-1] == 1:
658
+ image = np.expand_dims(image, axis=0)
659
+ else:
660
+ image = np.expand_dims(image, axis=-1)
661
+
662
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
663
+ warnings.warn(
664
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
665
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
666
+ FutureWarning,
667
+ )
668
+ image = np.concatenate(image, axis=0)
669
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
670
+ warnings.warn(
671
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
672
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
673
+ FutureWarning,
674
+ )
675
+ image = torch.cat(image, axis=0)
676
+
677
+ if not is_valid_image_imagelist(image):
678
+ raise ValueError(
679
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
680
+ )
681
+ if not isinstance(image, list):
682
+ image = [image]
683
+
684
+ if isinstance(image[0], PIL.Image.Image):
685
+ if crops_coords is not None:
686
+ image = [i.crop(crops_coords) for i in image]
687
+ if self.config.do_resize:
688
+ height, width = self.get_default_height_width(image[0], height, width)
689
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
690
+ if self.config.do_convert_rgb:
691
+ image = [self.convert_to_rgb(i) for i in image]
692
+ elif self.config.do_convert_grayscale:
693
+ image = [self.convert_to_grayscale(i) for i in image]
694
+ image = self.pil_to_numpy(image) # to np
695
+ image = self.numpy_to_pt(image) # to pt
696
+
697
+ elif isinstance(image[0], np.ndarray):
698
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
699
+
700
+ image = self.numpy_to_pt(image)
701
+
702
+ height, width = self.get_default_height_width(image, height, width)
703
+ if self.config.do_resize:
704
+ image = self.resize(image, height, width)
705
+
706
+ elif isinstance(image[0], torch.Tensor):
707
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
708
+
709
+ if self.config.do_convert_grayscale and image.ndim == 3:
710
+ image = image.unsqueeze(1)
711
+
712
+ channel = image.shape[1]
713
+ # don't need any preprocess if the image is latents
714
+ if channel == self.config.vae_latent_channels:
715
+ return image
716
+
717
+ height, width = self.get_default_height_width(image, height, width)
718
+ if self.config.do_resize:
719
+ image = self.resize(image, height, width)
720
+
721
+ # expected range [0,1], normalize to [-1,1]
722
+ do_normalize = self.config.do_normalize
723
+ if do_normalize and image.min() < 0:
724
+ warnings.warn(
725
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
726
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
727
+ FutureWarning,
728
+ )
729
+ do_normalize = False
730
+ if do_normalize:
731
+ image = self.normalize(image)
732
+
733
+ if self.config.do_binarize:
734
+ image = self.binarize(image)
735
+
736
+ return image
737
+
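Typical use of `preprocess`: a single PIL image comes back as a resized, normalized 4D tensor (illustrative sketch; the sizes below just follow the default `vae_scale_factor=8` rounding):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor(do_convert_rgb=True)
img = Image.new("RGB", (511, 767))
pixels = proc.preprocess(img)           # resized to 760x504 (multiples of 8), scaled to [-1, 1]
print(pixels.shape, pixels.dtype)       # torch.Size([1, 3, 760, 504]) torch.float32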
738
+ def postprocess(
739
+ self,
740
+ image: torch.Tensor,
741
+ output_type: str = "pil",
742
+ do_denormalize: Optional[List[bool]] = None,
743
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
744
+ """
745
+ Postprocess the image output from tensor to `output_type`.
746
+
747
+ Args:
748
+ image (`torch.Tensor`):
749
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
750
+ output_type (`str`, *optional*, defaults to `pil`):
751
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
752
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
753
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
754
+ `VaeImageProcessor` config.
755
+
756
+ Returns:
757
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
758
+ The postprocessed image.
759
+ """
760
+ if not isinstance(image, torch.Tensor):
761
+ raise ValueError(
762
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
763
+ )
764
+ if output_type not in ["latent", "pt", "np", "pil"]:
765
+ deprecation_message = (
766
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
767
+ "`pil`, `np`, `pt`, `latent`"
768
+ )
769
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
770
+ output_type = "np"
771
+
772
+ if output_type == "latent":
773
+ return image
774
+
775
+ image = self._denormalize_conditionally(image, do_denormalize)
776
+
777
+ if output_type == "pt":
778
+ return image
779
+
780
+ image = self.pt_to_numpy(image)
781
+
782
+ if output_type == "np":
783
+ return image
784
+
785
+ if output_type == "pil":
786
+ return self.numpy_to_pil(image)
787
+
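And the matching round trip through `postprocess` (illustrative sketch; a real pipeline would encode to and decode from VAE latents in between):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor()
img = Image.new("RGB", (512, 512), "white")
tensor = proc.preprocess(img)                        # [1, 3, 512, 512], values in [-1, 1]
out = proc.postprocess(tensor, output_type="pil")[0]
print(out.size, out.getpixel((0, 0)))                # (512, 512) (255, 255, 255)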
788
+ def apply_overlay(
789
+ self,
790
+ mask: PIL.Image.Image,
791
+ init_image: PIL.Image.Image,
792
+ image: PIL.Image.Image,
793
+ crop_coords: Optional[Tuple[int, int, int, int]] = None,
794
+ ) -> PIL.Image.Image:
795
+ r"""
796
+ Applies an overlay of the mask and the inpainted image on the original image.
797
+
798
+ Args:
799
+ mask (`PIL.Image.Image`):
800
+ The mask image that highlights regions to overlay.
801
+ init_image (`PIL.Image.Image`):
802
+ The original image to which the overlay is applied.
803
+ image (`PIL.Image.Image`):
804
+ The image to overlay onto the original.
805
+ crop_coords (`Tuple[int, int, int, int]`, *optional*):
806
+ Coordinates to crop the image. If provided, the image will be cropped accordingly.
807
+
808
+ Returns:
809
+ `PIL.Image.Image`:
810
+ The final image with the overlay applied.
811
+ """
812
+
813
+ width, height = init_image.width, init_image.height
814
+
815
+ init_image_masked = PIL.Image.new("RGBa", (width, height))
816
+ init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
817
+
818
+ init_image_masked = init_image_masked.convert("RGBA")
819
+
820
+ if crop_coords is not None:
821
+ x, y, x2, y2 = crop_coords
822
+ w = x2 - x
823
+ h = y2 - y
824
+ base_image = PIL.Image.new("RGBA", (width, height))
825
+ image = self.resize(image, height=h, width=w, resize_mode="crop")
826
+ base_image.paste(image, (x, y))
827
+ image = base_image.convert("RGB")
828
+
829
+ image = image.convert("RGBA")
830
+ image.alpha_composite(init_image_masked)
831
+ image = image.convert("RGB")
832
+
833
+ return image
834
+
835
+
836
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
837
+ """
838
+ Image processor for VAE LDM3D.
839
+
840
+ Args:
841
+ do_resize (`bool`, *optional*, defaults to `True`):
842
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
843
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
844
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
845
+ resample (`str`, *optional*, defaults to `lanczos`):
846
+ Resampling filter to use when resizing the image.
847
+ do_normalize (`bool`, *optional*, defaults to `True`):
848
+ Whether to normalize the image to [-1,1].
849
+ """
850
+
851
+ config_name = CONFIG_NAME
852
+
853
+ @register_to_config
854
+ def __init__(
855
+ self,
856
+ do_resize: bool = True,
857
+ vae_scale_factor: int = 8,
858
+ resample: str = "lanczos",
859
+ do_normalize: bool = True,
860
+ ):
861
+ super().__init__()
862
+
863
+ @staticmethod
864
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
865
+ r"""
866
+ Convert a NumPy image or a batch of images to a list of PIL images.
867
+
868
+ Args:
869
+ images (`np.ndarray`):
870
+ The input NumPy array of images, which can be a single image or a batch.
871
+
872
+ Returns:
873
+ `List[PIL.Image.Image]`:
874
+ A list of PIL images converted from the input NumPy array.
875
+ """
876
+ if images.ndim == 3:
877
+ images = images[None, ...]
878
+ images = (images * 255).round().astype("uint8")
879
+ if images.shape[-1] == 1:
880
+ # special case for grayscale (single channel) images
881
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
882
+ else:
883
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
884
+
885
+ return pil_images
886
+
887
+ @staticmethod
888
+ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
889
+ r"""
890
+ Convert a PIL image or a list of PIL images to NumPy arrays.
891
+
892
+ Args:
893
+ images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
894
+ The input image or list of images to be converted.
895
+
896
+ Returns:
897
+ `np.ndarray`:
898
+ A NumPy array of the converted images.
899
+ """
900
+ if not isinstance(images, list):
901
+ images = [images]
902
+
903
+ images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
904
+ images = np.stack(images, axis=0)
905
+ return images
906
+
907
+ @staticmethod
908
+ def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
909
+ r"""
910
+ Convert an RGB-like depth image to a depth map.
911
+
912
+ Args:
913
+ image (`Union[np.ndarray, torch.Tensor]`):
914
+ The RGB-like depth image to convert.
915
+
916
+ Returns:
917
+ `Union[np.ndarray, torch.Tensor]`:
918
+ The corresponding depth map.
919
+ """
920
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
921
+
922
+ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
923
+ r"""
924
+ Convert a NumPy depth image or a batch of images to a list of PIL images.
925
+
926
+ Args:
927
+ images (`np.ndarray`):
928
+ The input NumPy array of depth images, which can be a single image or a batch.
929
+
930
+ Returns:
931
+ `List[PIL.Image.Image]`:
932
+ A list of PIL images converted from the input NumPy depth images.
933
+ """
934
+ if images.ndim == 3:
935
+ images = images[None, ...]
936
+ images_depth = images[:, :, :, 3:]
937
+ if images.shape[-1] == 6:
938
+ images_depth = (images_depth * 255).round().astype("uint8")
939
+ pil_images = [
940
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
941
+ ]
942
+ elif images.shape[-1] == 4:
943
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
944
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
945
+ else:
946
+ raise Exception("Not supported")
947
+
948
+ return pil_images
949
+
950
+ def postprocess(
951
+ self,
952
+ image: torch.Tensor,
953
+ output_type: str = "pil",
954
+ do_denormalize: Optional[List[bool]] = None,
955
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
956
+ """
957
+ Postprocess the image output from tensor to `output_type`.
958
+
959
+ Args:
960
+ image (`torch.Tensor`):
961
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
962
+ output_type (`str`, *optional*, defaults to `pil`):
963
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
964
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
965
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
966
+ `VaeImageProcessor` config.
967
+
968
+ Returns:
969
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
970
+ The postprocessed image.
971
+ """
972
+ if not isinstance(image, torch.Tensor):
973
+ raise ValueError(
974
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
975
+ )
976
+ if output_type not in ["latent", "pt", "np", "pil"]:
977
+ deprecation_message = (
978
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
979
+ "`pil`, `np`, `pt`, `latent`"
980
+ )
981
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
982
+ output_type = "np"
983
+
984
+ image = self._denormalize_conditionally(image, do_denormalize)
985
+
986
+ image = self.pt_to_numpy(image)
987
+
988
+ if output_type == "np":
989
+ if image.shape[-1] == 6:
990
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
991
+ else:
992
+ image_depth = image[:, :, :, 3:]
993
+ return image[:, :, :, :3], image_depth
994
+
995
+ if output_type == "pil":
996
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
997
+ else:
998
+ raise Exception(f"This type {output_type} is not supported")
999
+
1000
+ def preprocess(
1001
+ self,
1002
+ rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
1003
+ depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
1004
+ height: Optional[int] = None,
1005
+ width: Optional[int] = None,
1006
+ target_res: Optional[int] = None,
1007
+ ) -> torch.Tensor:
1008
+ r"""
1009
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
1010
+
1011
+ Args:
1012
+ rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
1013
+ The RGB input image, which can be a single image or a batch.
1014
+ depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
1015
+ The depth input image, which can be a single image or a batch.
1016
+ height (`Optional[int]`, *optional*, defaults to `None`):
1017
+ The desired height of the processed image. If `None`, defaults to the height of the input image.
1018
+ width (`Optional[int]`, *optional*, defaults to `None`):
1019
+ The desired width of the processed image. If `None`, defaults to the width of the input image.
1020
+ target_res (`Optional[int]`, *optional*, defaults to `None`):
1021
+ Target resolution for resizing the images. If specified, overrides height and width.
1022
+
1023
+ Returns:
1024
+ `Tuple[torch.Tensor, torch.Tensor]`:
1025
+ A tuple containing the processed RGB and depth images as PyTorch tensors.
1026
+ """
1027
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
1028
+
1029
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
1030
+ if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
1031
+ raise Exception("This is not yet supported")
1032
+
1033
+ if isinstance(rgb, supported_formats):
1034
+ rgb = [rgb]
1035
+ depth = [depth]
1036
+ elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
1037
+ raise ValueError(
1038
+ f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
1039
+ )
1040
+
1041
+ if isinstance(rgb[0], PIL.Image.Image):
1042
+ if self.config.do_convert_rgb:
1043
+ raise Exception("This is not yet supported")
1044
+ # rgb = [self.convert_to_rgb(i) for i in rgb]
1045
+ # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
1046
+ if self.config.do_resize or target_res:
1047
+ height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
1048
+ rgb = [self.resize(i, height, width) for i in rgb]
1049
+ depth = [self.resize(i, height, width) for i in depth]
1050
+ rgb = self.pil_to_numpy(rgb) # to np
1051
+ rgb = self.numpy_to_pt(rgb) # to pt
1052
+
1053
+ depth = self.depth_pil_to_numpy(depth) # to np
1054
+ depth = self.numpy_to_pt(depth) # to pt
1055
+
1056
+ elif isinstance(rgb[0], np.ndarray):
1057
+ rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
1058
+ rgb = self.numpy_to_pt(rgb)
1059
+ height, width = self.get_default_height_width(rgb, height, width)
1060
+ if self.config.do_resize:
1061
+ rgb = self.resize(rgb, height, width)
1062
+
1063
+ depth = np.concatenate(depth, axis=0) if depth[0].ndim == 4 else np.stack(depth, axis=0)
1064
+ depth = self.numpy_to_pt(depth)
1065
+ height, width = self.get_default_height_width(depth, height, width)
1066
+ if self.config.do_resize:
1067
+ depth = self.resize(depth, height, width)
1068
+
1069
+ elif isinstance(rgb[0], torch.Tensor):
1070
+ raise Exception("This is not yet supported")
1071
+ # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
1072
+
1073
+ # if self.config.do_convert_grayscale and rgb.ndim == 3:
1074
+ # rgb = rgb.unsqueeze(1)
1075
+
1076
+ # channel = rgb.shape[1]
1077
+
1078
+ # height, width = self.get_default_height_width(rgb, height, width)
1079
+ # if self.config.do_resize:
1080
+ # rgb = self.resize(rgb, height, width)
1081
+
1082
+ # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
1083
+
1084
+ # if self.config.do_convert_grayscale and depth.ndim == 3:
1085
+ # depth = depth.unsqueeze(1)
1086
+
1087
+ # channel = depth.shape[1]
1088
+ # # don't need any preprocess if the image is latents
1089
+ # if depth == 4:
1090
+ # return rgb, depth
1091
+
1092
+ # height, width = self.get_default_height_width(depth, height, width)
1093
+ # if self.config.do_resize:
1094
+ # depth = self.resize(depth, height, width)
1095
+ # expected range [0,1], normalize to [-1,1]
1096
+ do_normalize = self.config.do_normalize
1097
+ if rgb.min() < 0 and do_normalize:
1098
+ warnings.warn(
1099
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
1100
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
1101
+ FutureWarning,
1102
+ )
1103
+ do_normalize = False
1104
+
1105
+ if do_normalize:
1106
+ rgb = self.normalize(rgb)
1107
+ depth = self.normalize(depth)
1108
+
1109
+ if self.config.do_binarize:
1110
+ rgb = self.binarize(rgb)
1111
+ depth = self.binarize(depth)
1112
+
1113
+ return rgb, depth
1114
+
1115
+
1116
+ class IPAdapterMaskProcessor(VaeImageProcessor):
1117
+ """
1118
+ Image processor for IP Adapter image masks.
1119
+
1120
+ Args:
1121
+ do_resize (`bool`, *optional*, defaults to `True`):
1122
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
1123
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
1124
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
1125
+ resample (`str`, *optional*, defaults to `lanczos`):
1126
+ Resampling filter to use when resizing the image.
1127
+ do_normalize (`bool`, *optional*, defaults to `False`):
1128
+ Whether to normalize the image to [-1,1].
1129
+ do_binarize (`bool`, *optional*, defaults to `True`):
1130
+ Whether to binarize the image to 0/1.
1131
+ do_convert_grayscale (`bool`, *optional*, defaults to `True`):
1132
+ Whether to convert the images to grayscale format.
1133
+
1134
+ """
1135
+
1136
+ config_name = CONFIG_NAME
1137
+
1138
+ @register_to_config
1139
+ def __init__(
1140
+ self,
1141
+ do_resize: bool = True,
1142
+ vae_scale_factor: int = 8,
1143
+ resample: str = "lanczos",
1144
+ do_normalize: bool = False,
1145
+ do_binarize: bool = True,
1146
+ do_convert_grayscale: bool = True,
1147
+ ):
1148
+ super().__init__(
1149
+ do_resize=do_resize,
1150
+ vae_scale_factor=vae_scale_factor,
1151
+ resample=resample,
1152
+ do_normalize=do_normalize,
1153
+ do_binarize=do_binarize,
1154
+ do_convert_grayscale=do_convert_grayscale,
1155
+ )
1156
+
1157
+ @staticmethod
1158
+ def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
1159
+ """
1160
+ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
1161
+ aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
1162
+
1163
+ Args:
1164
+ mask (`torch.Tensor`):
1165
+ The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
1166
+ batch_size (`int`):
1167
+ The batch size.
1168
+ num_queries (`int`):
1169
+ The number of queries.
1170
+ value_embed_dim (`int`):
1171
+ The dimensionality of the value embeddings.
1172
+
1173
+ Returns:
1174
+ `torch.Tensor`:
1175
+ The downsampled mask tensor.
1176
+
1177
+ """
1178
+ o_h = mask.shape[1]
1179
+ o_w = mask.shape[2]
1180
+ ratio = o_w / o_h
1181
+ mask_h = int(math.sqrt(num_queries / ratio))
1182
+ mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
1183
+ mask_w = num_queries // mask_h
1184
+
1185
+ mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
1186
+
1187
+ # Repeat batch_size times
1188
+ if mask_downsample.shape[0] < batch_size:
1189
+ mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
1190
+
1191
+ mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
1192
+
1193
+ downsampled_area = mask_h * mask_w
1194
+ # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
1195
+ # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
1196
+ if downsampled_area < num_queries:
1197
+ warnings.warn(
1198
+ "The aspect ratio of the mask does not match the aspect ratio of the output image. "
1199
+ "Please update your masks or adjust the output size for optimal performance.",
1200
+ UserWarning,
1201
+ )
1202
+ mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
1203
+ # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
1204
+ if downsampled_area > num_queries:
1205
+ warnings.warn(
1206
+ "The aspect ratio of the mask does not match the aspect ratio of the output image. "
1207
+ "Please update your masks or adjust the output size for optimal performance.",
1208
+ UserWarning,
1209
+ )
1210
+ mask_downsample = mask_downsample[:, :num_queries]
1211
+
1212
+ # Repeat last dimension to match SDPA output shape
1213
+ mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
1214
+ 1, 1, value_embed_dim
1215
+ )
1216
+
1217
+ return mask_downsample
1218
+
1219
+
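A worked example of the shape math in `downsample` (illustrative sketch; `value_embed_dim=640` is an arbitrary example dimension): for a 512x512 mask and a 64x64 query grid, num_queries = 4096 and the mask is resized to 64x64 before being broadcast to the SDPA output shape.

import torch
from diffusers.image_processor import IPAdapterMaskProcessor

mask = torch.ones(1, 512, 512)          # one mask, same aspect ratio as the output image
out = IPAdapterMaskProcessor.downsample(mask, batch_size=2, num_queries=4096, value_embed_dim=640)
print(out.shape)                        # torch.Size([2, 4096, 640])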
1220
+ class PixArtImageProcessor(VaeImageProcessor):
1221
+ """
1222
+ Image processor for PixArt image resize and crop.
1223
+
1224
+ Args:
1225
+ do_resize (`bool`, *optional*, defaults to `True`):
1226
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
1227
+ `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
1228
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
1229
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
1230
+ resample (`str`, *optional*, defaults to `lanczos`):
1231
+ Resampling filter to use when resizing the image.
1232
+ do_normalize (`bool`, *optional*, defaults to `True`):
1233
+ Whether to normalize the image to [-1,1].
1234
+ do_binarize (`bool`, *optional*, defaults to `False`):
1235
+ Whether to binarize the image to 0/1.
1236
+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
1237
+ Whether to convert the images to RGB format.
1238
+ do_convert_grayscale (`bool`, *optional*, defaults to `False`):
1239
+ Whether to convert the images to grayscale format.
1240
+ """
1241
+
1242
+ @register_to_config
1243
+ def __init__(
1244
+ self,
1245
+ do_resize: bool = True,
1246
+ vae_scale_factor: int = 8,
1247
+ resample: str = "lanczos",
1248
+ do_normalize: bool = True,
1249
+ do_binarize: bool = False,
1250
+ do_convert_grayscale: bool = False,
1251
+ ):
1252
+ super().__init__(
1253
+ do_resize=do_resize,
1254
+ vae_scale_factor=vae_scale_factor,
1255
+ resample=resample,
1256
+ do_normalize=do_normalize,
1257
+ do_binarize=do_binarize,
1258
+ do_convert_grayscale=do_convert_grayscale,
1259
+ )
1260
+
1261
+ @staticmethod
1262
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
1263
+ r"""
1264
+ Returns the binned height and width based on the aspect ratio.
1265
+
1266
+ Args:
1267
+ height (`int`): The height of the image.
1268
+ width (`int`): The width of the image.
1269
+ ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
1270
+
1271
+ Returns:
1272
+ `Tuple[int, int]`: The closest binned height and width.
1273
+ """
1274
+ ar = float(height / width)
1275
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
1276
+ default_hw = ratios[closest_ratio]
1277
+ return int(default_hw[0]), int(default_hw[1])
1278
+
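A small sketch of aspect-ratio binning with a made-up ratio table (real pipelines ship much larger tables; the numbers here are only for illustration):

from diffusers.image_processor import PixArtImageProcessor

ratios = {"0.5": (512, 1024), "1.0": (1024, 1024), "2.0": (1024, 512)}      # ratio -> (height, width)
print(PixArtImageProcessor.classify_height_width_bin(640, 1024, ratios))    # 0.625 is closest to 0.5 -> (512, 1024)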
1279
+ @staticmethod
1280
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
1281
+ r"""
1282
+ Resizes and crops a tensor of images to the specified dimensions.
1283
+
1284
+ Args:
1285
+ samples (`torch.Tensor`):
1286
+ A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
1287
+ and W is the width.
1288
+ new_width (`int`): The desired width of the output images.
1289
+ new_height (`int`): The desired height of the output images.
1290
+
1291
+ Returns:
1292
+ `torch.Tensor`: A tensor containing the resized and cropped images.
1293
+ """
1294
+ orig_height, orig_width = samples.shape[2], samples.shape[3]
1295
+
1296
+ # Check if resizing is needed
1297
+ if orig_height != new_height or orig_width != new_width:
1298
+ ratio = max(new_height / orig_height, new_width / orig_width)
1299
+ resized_width = int(orig_width * ratio)
1300
+ resized_height = int(orig_height * ratio)
1301
+
1302
+ # Resize
1303
+ samples = F.interpolate(
1304
+ samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
1305
+ )
1306
+
1307
+ # Center Crop
1308
+ start_x = (resized_width - new_width) // 2
1309
+ end_x = start_x + new_width
1310
+ start_y = (resized_height - new_height) // 2
1311
+ end_y = start_y + new_height
1312
+ samples = samples[:, :, start_y:end_y, start_x:end_x]
1313
+
1314
+ return samples
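`resize_and_crop_tensor` scales the batch so that both target dimensions are covered and then center-crops; a quick shape check (illustrative sketch):

import torch
from diffusers.image_processor import PixArtImageProcessor

frames = torch.rand(2, 3, 512, 768)     # a 3:2 batch
out = PixArtImageProcessor.resize_and_crop_tensor(frames, new_width=512, new_height=512)
print(out.shape)                        # torch.Size([2, 3, 512, 512])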
icedit/diffusers/loaders/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
4
+ from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
5
+
6
+
7
+ def text_encoder_lora_state_dict(text_encoder):
8
+ deprecate(
9
+ "text_encoder_load_state_dict in `models`",
10
+ "0.27.0",
11
+ "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
12
+ )
13
+ state_dict = {}
14
+
15
+ for name, module in text_encoder_attn_modules(text_encoder):
16
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
17
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
18
+
19
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
20
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
21
+
22
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
23
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
24
+
25
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
26
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
27
+
28
+ return state_dict
29
+
30
+
31
+ if is_transformers_available():
32
+
33
+ def text_encoder_attn_modules(text_encoder):
34
+ deprecate(
35
+ "text_encoder_attn_modules in `models`",
36
+ "0.27.0",
37
+ "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
38
+ )
39
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
40
+
41
+ attn_modules = []
42
+
43
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
44
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
45
+ name = f"text_model.encoder.layers.{i}.self_attn"
46
+ mod = layer.self_attn
47
+ attn_modules.append((name, mod))
48
+ else:
49
+ raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
50
+
51
+ return attn_modules
52
+
53
+
54
+ _import_structure = {}
55
+
56
+ if is_torch_available():
57
+ _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
58
+ _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
59
+ _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
60
+ _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
61
+ _import_structure["utils"] = ["AttnProcsLayers"]
62
+ if is_transformers_available():
63
+ _import_structure["single_file"] = ["FromSingleFileMixin"]
64
+ _import_structure["lora_pipeline"] = [
65
+ "AmusedLoraLoaderMixin",
66
+ "StableDiffusionLoraLoaderMixin",
67
+ "SD3LoraLoaderMixin",
68
+ "StableDiffusionXLLoraLoaderMixin",
69
+ "LTXVideoLoraLoaderMixin",
70
+ "LoraLoaderMixin",
71
+ "FluxLoraLoaderMixin",
72
+ "CogVideoXLoraLoaderMixin",
73
+ "Mochi1LoraLoaderMixin",
74
+ "HunyuanVideoLoraLoaderMixin",
75
+ "SanaLoraLoaderMixin",
76
+ ]
77
+ _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
78
+ _import_structure["ip_adapter"] = [
79
+ "IPAdapterMixin",
80
+ "FluxIPAdapterMixin",
81
+ "SD3IPAdapterMixin",
82
+ ]
83
+
84
+ _import_structure["peft"] = ["PeftAdapterMixin"]
85
+
86
+
87
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
88
+ if is_torch_available():
89
+ from .single_file_model import FromOriginalModelMixin
90
+ from .transformer_flux import FluxTransformer2DLoadersMixin
91
+ from .transformer_sd3 import SD3Transformer2DLoadersMixin
92
+ from .unet import UNet2DConditionLoadersMixin
93
+ from .utils import AttnProcsLayers
94
+
95
+ if is_transformers_available():
96
+ from .ip_adapter import (
97
+ FluxIPAdapterMixin,
98
+ IPAdapterMixin,
99
+ SD3IPAdapterMixin,
100
+ )
101
+ from .lora_pipeline import (
102
+ AmusedLoraLoaderMixin,
103
+ CogVideoXLoraLoaderMixin,
104
+ FluxLoraLoaderMixin,
105
+ HunyuanVideoLoraLoaderMixin,
106
+ LoraLoaderMixin,
107
+ LTXVideoLoraLoaderMixin,
108
+ Mochi1LoraLoaderMixin,
109
+ SanaLoraLoaderMixin,
110
+ SD3LoraLoaderMixin,
111
+ StableDiffusionLoraLoaderMixin,
112
+ StableDiffusionXLLoraLoaderMixin,
113
+ )
114
+ from .single_file import FromSingleFileMixin
115
+ from .textual_inversion import TextualInversionLoaderMixin
116
+
117
+ from .peft import PeftAdapterMixin
118
+ else:
119
+ import sys
120
+
121
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
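The `_LazyModule` indirection above means `import diffusers.loaders` stays cheap: submodules such as `ip_adapter` are only imported the first time one of their exported names is accessed. A quick way to observe this (illustrative; assumes diffusers, torch and transformers are installed so the loader classes are registered):

import diffusers.loaders as loaders

# Accessing the attribute triggers the import of the submodule that defines it.
IPAdapterMixin = loaders.IPAdapterMixin
print(IPAdapterMixin.__module__)        # diffusers.loaders.ip_adapter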
icedit/diffusers/loaders/ip_adapter.py ADDED
@@ -0,0 +1,871 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from huggingface_hub.utils import validate_hf_hub_args
21
+ from safetensors import safe_open
22
+
23
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
24
+ from ..utils import (
25
+ USE_PEFT_BACKEND,
26
+ _get_model_file,
27
+ is_accelerate_available,
28
+ is_torch_version,
29
+ is_transformers_available,
30
+ logging,
31
+ )
32
+ from .unet_loader_utils import _maybe_expand_lora_scales
33
+
34
+
35
+ if is_transformers_available():
36
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
37
+
38
+ from ..models.attention_processor import (
39
+ AttnProcessor,
40
+ AttnProcessor2_0,
41
+ FluxAttnProcessor2_0,
42
+ FluxIPAdapterJointAttnProcessor2_0,
43
+ IPAdapterAttnProcessor,
44
+ IPAdapterAttnProcessor2_0,
45
+ IPAdapterXFormersAttnProcessor,
46
+ JointAttnProcessor2_0,
47
+ SD3IPAdapterJointAttnProcessor2_0,
48
+ )
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ class IPAdapterMixin:
55
+ """Mixin for handling IP Adapters."""
56
+
57
+ @validate_hf_hub_args
58
+ def load_ip_adapter(
59
+ self,
60
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
61
+ subfolder: Union[str, List[str]],
62
+ weight_name: Union[str, List[str]],
63
+ image_encoder_folder: Optional[str] = "image_encoder",
64
+ **kwargs,
65
+ ):
66
+ """
67
+ Parameters:
68
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
69
+ Can be either:
70
+
71
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
72
+ the Hub.
73
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
74
+ with [`ModelMixin.save_pretrained`].
75
+ - A [torch state
76
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
77
+ subfolder (`str` or `List[str]`):
78
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
79
+ list is passed, it should have the same length as `weight_name`.
80
+ weight_name (`str` or `List[str]`):
81
+ The name of the weight file to load. If a list is passed, it should have the same length as
82
+ `subfolder`.
83
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
84
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
85
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
86
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
87
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
88
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
89
+ `image_encoder_folder="different_subfolder/image_encoder"`.
90
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
91
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
92
+ is not used.
93
+ force_download (`bool`, *optional*, defaults to `False`):
94
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
95
+ cached versions if they exist.
96
+
97
+ proxies (`Dict[str, str]`, *optional*):
98
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
99
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
100
+ local_files_only (`bool`, *optional*, defaults to `False`):
101
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
102
+ won't be downloaded from the Hub.
103
+ token (`str` or *bool*, *optional*):
104
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
105
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
106
+ revision (`str`, *optional*, defaults to `"main"`):
107
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
108
+ allowed by Git.
109
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
110
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
111
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
112
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
113
+ argument to `True` will raise an error.
114
+ """
115
+
116
+ # handle the list inputs for multiple IP Adapters
117
+ if not isinstance(weight_name, list):
118
+ weight_name = [weight_name]
119
+
120
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
121
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
122
+ if len(pretrained_model_name_or_path_or_dict) == 1:
123
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
124
+
125
+ if not isinstance(subfolder, list):
126
+ subfolder = [subfolder]
127
+ if len(subfolder) == 1:
128
+ subfolder = subfolder * len(weight_name)
129
+
130
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
131
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
132
+
133
+ if len(weight_name) != len(subfolder):
134
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
135
+
136
+ # Load the main state dict first.
137
+ cache_dir = kwargs.pop("cache_dir", None)
138
+ force_download = kwargs.pop("force_download", False)
139
+ proxies = kwargs.pop("proxies", None)
140
+ local_files_only = kwargs.pop("local_files_only", None)
141
+ token = kwargs.pop("token", None)
142
+ revision = kwargs.pop("revision", None)
143
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
144
+
145
+ if low_cpu_mem_usage and not is_accelerate_available():
146
+ low_cpu_mem_usage = False
147
+ logger.warning(
148
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
149
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
150
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
151
+ " install accelerate\n```\n."
152
+ )
153
+
154
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
155
+ raise NotImplementedError(
156
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
157
+ " `low_cpu_mem_usage=False`."
158
+ )
159
+
160
+ user_agent = {
161
+ "file_type": "attn_procs_weights",
162
+ "framework": "pytorch",
163
+ }
164
+ state_dicts = []
165
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
166
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
167
+ ):
168
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
169
+ model_file = _get_model_file(
170
+ pretrained_model_name_or_path_or_dict,
171
+ weights_name=weight_name,
172
+ cache_dir=cache_dir,
173
+ force_download=force_download,
174
+ proxies=proxies,
175
+ local_files_only=local_files_only,
176
+ token=token,
177
+ revision=revision,
178
+ subfolder=subfolder,
179
+ user_agent=user_agent,
180
+ )
181
+ if weight_name.endswith(".safetensors"):
182
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
183
+ with safe_open(model_file, framework="pt", device="cpu") as f:
184
+ for key in f.keys():
185
+ if key.startswith("image_proj."):
186
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
187
+ elif key.startswith("ip_adapter."):
188
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
189
+ else:
190
+ state_dict = load_state_dict(model_file)
191
+ else:
192
+ state_dict = pretrained_model_name_or_path_or_dict
193
+
194
+ keys = list(state_dict.keys())
195
+ if "image_proj" not in keys and "ip_adapter" not in keys:
196
+ raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
197
+
198
+ state_dicts.append(state_dict)
199
+
200
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
201
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
202
+ if image_encoder_folder is not None:
203
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
204
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
205
+ if image_encoder_folder.count("/") == 0:
206
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
207
+ else:
208
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
209
+
210
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
211
+ pretrained_model_name_or_path_or_dict,
212
+ subfolder=image_encoder_subfolder,
213
+ low_cpu_mem_usage=low_cpu_mem_usage,
214
+ cache_dir=cache_dir,
215
+ local_files_only=local_files_only,
216
+ ).to(self.device, dtype=self.dtype)
217
+ self.register_modules(image_encoder=image_encoder)
218
+ else:
219
+ raise ValueError(
220
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
221
+ )
222
+ else:
223
+ logger.warning(
224
+ "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
225
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
226
+ )
227
+
228
+ # create feature extractor if it has not been registered to the pipeline yet
229
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
230
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
231
+ default_clip_size = 224
232
+ clip_image_size = (
233
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
234
+ )
235
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
236
+ self.register_modules(feature_extractor=feature_extractor)
237
+
238
+ # load ip-adapter into unet
239
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
240
+ unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
241
+
242
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
243
+ if extra_loras != {}:
244
+ if not USE_PEFT_BACKEND:
245
+ logger.warning("PEFT backend is required to load these weights.")
246
+ else:
247
+ # apply the IP Adapter Face ID LoRA weights
248
+ peft_config = getattr(unet, "peft_config", {})
249
+ for k, lora in extra_loras.items():
250
+ if f"faceid_{k}" not in peft_config:
251
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
252
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
253
+
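For context, a minimal usage sketch of the loader defined above on a UNet-based pipeline. This is a sketch only, not part of this commit; the base model id, adapter repo, subfolder, and weight file name are illustrative assumptions.

```py
# Hypothetical usage sketch for IPAdapterMixin.load_ip_adapter (checkpoint names are examples).
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Assumed IP-Adapter repo layout: repo id, subfolder, and weight file are illustrative.
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors"
)
pipeline.set_ip_adapter_scale(0.6)  # blend image-prompt vs. text-prompt conditioning
```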
254
+ def set_ip_adapter_scale(self, scale):
255
+ """
256
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
257
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
258
+
259
+ Example:
260
+
261
+ ```py
262
+ # To use original IP-Adapter
263
+ scale = 1.0
264
+ pipeline.set_ip_adapter_scale(scale)
265
+
266
+ # To use style block only
267
+ scale = {
268
+ "up": {"block_0": [0.0, 1.0, 0.0]},
269
+ }
270
+ pipeline.set_ip_adapter_scale(scale)
271
+
272
+ # To use style+layout blocks
273
+ scale = {
274
+ "down": {"block_2": [0.0, 1.0]},
275
+ "up": {"block_0": [0.0, 1.0, 0.0]},
276
+ }
277
+ pipeline.set_ip_adapter_scale(scale)
278
+
279
+ # To use style and layout from 2 reference images
280
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
281
+ pipeline.set_ip_adapter_scale(scales)
282
+ ```
283
+ """
284
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
285
+ if not isinstance(scale, list):
286
+ scale = [scale]
287
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
288
+
289
+ for attn_name, attn_processor in unet.attn_processors.items():
290
+ if isinstance(
291
+ attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
292
+ ):
293
+ if len(scale_configs) != len(attn_processor.scale):
294
+ raise ValueError(
295
+ f"Cannot assign {len(scale_configs)} scale_configs to "
296
+ f"{len(attn_processor.scale)} IP-Adapter."
297
+ )
298
+ elif len(scale_configs) == 1:
299
+ scale_configs = scale_configs * len(attn_processor.scale)
300
+ for i, scale_config in enumerate(scale_configs):
301
+ if isinstance(scale_config, dict):
302
+ for k, s in scale_config.items():
303
+ if attn_name.startswith(k):
304
+ attn_processor.scale[i] = s
305
+ else:
306
+ attn_processor.scale[i] = scale_config
307
+
308
+ def unload_ip_adapter(self):
309
+ """
310
+ Unloads the IP Adapter weights
311
+
312
+ Examples:
313
+
314
+ ```python
315
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
316
+ >>> pipeline.unload_ip_adapter()
317
+ >>> ...
318
+ ```
319
+ """
320
+ # remove CLIP image encoder
321
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
322
+ self.image_encoder = None
323
+ self.register_to_config(image_encoder=[None, None])
324
+
325
+ # remove feature extractor only when safety_checker is None as safety_checker uses
326
+ # the feature_extractor later
327
+ if not hasattr(self, "safety_checker"):
328
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
329
+ self.feature_extractor = None
330
+ self.register_to_config(feature_extractor=[None, None])
331
+
332
+ # remove hidden encoder
333
+ self.unet.encoder_hid_proj = None
334
+ self.unet.config.encoder_hid_dim_type = None
335
+
336
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
337
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
338
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
339
+ self.unet.text_encoder_hid_proj = None
340
+ self.unet.config.encoder_hid_dim_type = "text_proj"
341
+
342
+ # restore original Unet attention processors layers
343
+ attn_procs = {}
344
+ for name, value in self.unet.attn_processors.items():
345
+ attn_processor_class = (
346
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
347
+ )
348
+ attn_procs[name] = (
349
+ attn_processor_class
350
+ if isinstance(
351
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
352
+ )
353
+ else value.__class__()
354
+ )
355
+ self.unet.set_attn_processor(attn_procs)
356
+
357
+
358
+ class FluxIPAdapterMixin:
359
+ """Mixin for handling Flux IP Adapters."""
360
+
361
+ @validate_hf_hub_args
362
+ def load_ip_adapter(
363
+ self,
364
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
365
+ weight_name: Union[str, List[str]],
366
+ subfolder: Optional[Union[str, List[str]]] = "",
367
+ image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
368
+ image_encoder_subfolder: Optional[str] = "",
369
+ image_encoder_dtype: torch.dtype = torch.float16,
370
+ **kwargs,
371
+ ):
372
+ """
373
+ Parameters:
374
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
375
+ Can be either:
376
+
377
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
378
+ the Hub.
379
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
380
+ with [`ModelMixin.save_pretrained`].
381
+ - A [torch state
382
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
383
+ subfolder (`str` or `List[str]`):
384
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
385
+ list is passed, it should have the same length as `weight_name`.
386
+ weight_name (`str` or `List[str]`):
387
+ The name of the weight file to load. If a list is passed, it should have the same length as
388
+ `subfolder`.
389
+ image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `image_encoder`):
390
+ Can be either:
391
+
392
+ - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
393
+ hosted on the Hub.
394
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
395
+ with [`ModelMixin.save_pretrained`].
396
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
397
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
398
+ is not used.
399
+ force_download (`bool`, *optional*, defaults to `False`):
400
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
401
+ cached versions if they exist.
402
+
403
+ proxies (`Dict[str, str]`, *optional*):
404
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
405
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
406
+ local_files_only (`bool`, *optional*, defaults to `False`):
407
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
408
+ won't be downloaded from the Hub.
409
+ token (`str` or *bool*, *optional*):
410
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
411
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
412
+ revision (`str`, *optional*, defaults to `"main"`):
413
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
414
+ allowed by Git.
415
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
416
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
417
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
418
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
419
+ argument to `True` will raise an error.
420
+ """
421
+
422
+ # handle the list inputs for multiple IP Adapters
423
+ if not isinstance(weight_name, list):
424
+ weight_name = [weight_name]
425
+
426
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
427
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
428
+ if len(pretrained_model_name_or_path_or_dict) == 1:
429
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
430
+
431
+ if not isinstance(subfolder, list):
432
+ subfolder = [subfolder]
433
+ if len(subfolder) == 1:
434
+ subfolder = subfolder * len(weight_name)
435
+
436
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
437
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
438
+
439
+ if len(weight_name) != len(subfolder):
440
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
441
+
442
+ # Load the main state dict first.
443
+ cache_dir = kwargs.pop("cache_dir", None)
444
+ force_download = kwargs.pop("force_download", False)
445
+ proxies = kwargs.pop("proxies", None)
446
+ local_files_only = kwargs.pop("local_files_only", None)
447
+ token = kwargs.pop("token", None)
448
+ revision = kwargs.pop("revision", None)
449
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
450
+
451
+ if low_cpu_mem_usage and not is_accelerate_available():
452
+ low_cpu_mem_usage = False
453
+ logger.warning(
454
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
455
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
456
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
457
+ " install accelerate\n```\n."
458
+ )
459
+
460
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
461
+ raise NotImplementedError(
462
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
463
+ " `low_cpu_mem_usage=False`."
464
+ )
465
+
466
+ user_agent = {
467
+ "file_type": "attn_procs_weights",
468
+ "framework": "pytorch",
469
+ }
470
+ state_dicts = []
471
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
472
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
473
+ ):
474
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
475
+ model_file = _get_model_file(
476
+ pretrained_model_name_or_path_or_dict,
477
+ weights_name=weight_name,
478
+ cache_dir=cache_dir,
479
+ force_download=force_download,
480
+ proxies=proxies,
481
+ local_files_only=local_files_only,
482
+ token=token,
483
+ revision=revision,
484
+ subfolder=subfolder,
485
+ user_agent=user_agent,
486
+ )
487
+ if weight_name.endswith(".safetensors"):
488
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
489
+ with safe_open(model_file, framework="pt", device="cpu") as f:
490
+ image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
491
+ ip_adapter_keys = ["double_blocks.", "ip_adapter."]
492
+ for key in f.keys():
493
+ if any(key.startswith(prefix) for prefix in image_proj_keys):
494
+ diffusers_name = ".".join(key.split(".")[1:])
495
+ state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
496
+ elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
497
+ diffusers_name = (
498
+ ".".join(key.split(".")[1:])
499
+ .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
500
+ .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
501
+ .replace("processor.", "")
502
+ )
503
+ state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
504
+ else:
505
+ state_dict = load_state_dict(model_file)
506
+ else:
507
+ state_dict = pretrained_model_name_or_path_or_dict
508
+
509
+ keys = list(state_dict.keys())
510
+ if keys != ["image_proj", "ip_adapter"]:
511
+ raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
512
+
513
+ state_dicts.append(state_dict)
514
+
515
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
516
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
517
+ if image_encoder_pretrained_model_name_or_path is not None:
518
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
519
+ logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
520
+ image_encoder = (
521
+ CLIPVisionModelWithProjection.from_pretrained(
522
+ image_encoder_pretrained_model_name_or_path,
523
+ subfolder=image_encoder_subfolder,
524
+ low_cpu_mem_usage=low_cpu_mem_usage,
525
+ cache_dir=cache_dir,
526
+ local_files_only=local_files_only,
527
+ )
528
+ .to(self.device, dtype=image_encoder_dtype)
529
+ .eval()
530
+ )
531
+ self.register_modules(image_encoder=image_encoder)
532
+ else:
533
+ raise ValueError(
534
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
535
+ )
536
+ else:
537
+ logger.warning(
538
+ "image_encoder is not loaded since `image_encoder_pretrained_model_name_or_path=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
539
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
540
+ )
541
+
542
+ # create feature extractor if it has not been registered to the pipeline yet
543
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
544
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
545
+ default_clip_size = 224
546
+ clip_image_size = (
547
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
548
+ )
549
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
550
+ self.register_modules(feature_extractor=feature_extractor)
551
+
552
+ # load ip-adapter into transformer
553
+ self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
554
+
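A minimal sketch of how this Flux loader might be invoked. The base model, adapter repo, weight file, and image encoder ids below are illustrative assumptions, not part of this file.

```py
# Hypothetical usage sketch for FluxIPAdapterMixin.load_ip_adapter (checkpoint names are examples).
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",            # assumed adapter repo
    weight_name="ip_adapter.safetensors",  # assumed weight file name
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipeline.set_ip_adapter_scale(1.0)
```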
555
+ def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
556
+ """
557
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
558
+ granular control over each IP-Adapter behavior. A config can be a float or a list.
559
+
560
+ A `float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A `List[float]`
561
+ must match the number of blocks in length and is repeated for each IP adapter. A `List[List[float]]` must match the
562
+ number of IP adapters, and each inner list must match the number of blocks.
563
+
564
+ Example:
565
+
566
+ ```py
567
+ # To use original IP-Adapter
568
+ scale = 1.0
569
+ pipeline.set_ip_adapter_scale(scale)
570
+
571
+
572
+ def LinearStrengthModel(start, finish, size):
573
+ return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
574
+
575
+
576
+ ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
577
+ pipeline.set_ip_adapter_scale(ip_strengths)
578
+ ```
579
+ """
580
+ transformer = self.transformer
581
+ if not isinstance(scale, list):
582
+ scale = [[scale] * transformer.config.num_layers]
583
+ elif isinstance(scale, list) and isinstance(scale[0], (int, float)):
584
+ if len(scale) != transformer.config.num_layers:
585
+ raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
586
+ scale = [scale]
587
+
588
+ scale_configs = scale
589
+
590
+ key_id = 0
591
+ for attn_name, attn_processor in transformer.attn_processors.items():
592
+ if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
593
+ if len(scale_configs) != len(attn_processor.scale):
594
+ raise ValueError(
595
+ f"Cannot assign {len(scale_configs)} scale_configs to "
596
+ f"{len(attn_processor.scale)} IP-Adapter."
597
+ )
598
+ elif len(scale_configs) == 1:
599
+ scale_configs = scale_configs * len(attn_processor.scale)
600
+ for i, scale_config in enumerate(scale_configs):
601
+ attn_processor.scale[i] = scale_config[key_id]
602
+ key_id += 1
603
+
604
+ def unload_ip_adapter(self):
605
+ """
606
+ Unloads the IP Adapter weights
607
+
608
+ Examples:
609
+
610
+ ```python
611
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
612
+ >>> pipeline.unload_ip_adapter()
613
+ >>> ...
614
+ ```
615
+ """
616
+ # remove CLIP image encoder
617
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
618
+ self.image_encoder = None
619
+ self.register_to_config(image_encoder=[None, None])
620
+
621
+ # remove feature extractor only when safety_checker is None as safety_checker uses
622
+ # the feature_extractor later
623
+ if not hasattr(self, "safety_checker"):
624
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
625
+ self.feature_extractor = None
626
+ self.register_to_config(feature_extractor=[None, None])
627
+
628
+ # remove hidden encoder
629
+ self.transformer.encoder_hid_proj = None
630
+ self.transformer.config.encoder_hid_dim_type = None
631
+
632
+ # restore original Transformer attention processors layers
633
+ attn_procs = {}
634
+ for name, value in self.transformer.attn_processors.items():
635
+ attn_processor_class = FluxAttnProcessor2_0()
636
+ attn_procs[name] = (
637
+ attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
638
+ )
639
+ self.transformer.set_attn_processor(attn_procs)
640
+
641
+
642
+ class SD3IPAdapterMixin:
643
+ """Mixin for handling StableDiffusion 3 IP Adapters."""
644
+
645
+ @property
646
+ def is_ip_adapter_active(self) -> bool:
647
+ """Checks if IP-Adapter is loaded and scale > 0.
648
+
649
+ IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
650
+ the image context is irrelevant.
651
+
652
+ Returns:
653
+ `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
654
+ """
655
+ scales = [
656
+ attn_proc.scale
657
+ for attn_proc in self.transformer.attn_processors.values()
658
+ if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
659
+ ]
660
+
661
+ return len(scales) > 0 and any(scale > 0 for scale in scales)
662
+
663
+ @validate_hf_hub_args
664
+ def load_ip_adapter(
665
+ self,
666
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
667
+ weight_name: str = "ip-adapter.safetensors",
668
+ subfolder: Optional[str] = None,
669
+ image_encoder_folder: Optional[str] = "image_encoder",
670
+ **kwargs,
671
+ ) -> None:
672
+ """
673
+ Parameters:
674
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
675
+ Can be either:
676
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
677
+ the Hub.
678
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
679
+ with [`ModelMixin.save_pretrained`].
680
+ - A [torch state
681
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
682
+ weight_name (`str`, defaults to "ip-adapter.safetensors"):
683
+ The name of the weight file to load. If a list is passed, it should have the same length as
684
+ `subfolder`.
685
+ subfolder (`str`, *optional*):
686
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
687
+ list is passed, it should have the same length as `weight_name`.
688
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
689
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
690
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
691
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
692
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
693
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
694
+ `image_encoder_folder="different_subfolder/image_encoder"`.
695
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
696
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
697
+ is not used.
698
+ force_download (`bool`, *optional*, defaults to `False`):
699
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
700
+ cached versions if they exist.
701
+ proxies (`Dict[str, str]`, *optional*):
702
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
703
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
704
+ local_files_only (`bool`, *optional*, defaults to `False`):
705
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
706
+ won't be downloaded from the Hub.
707
+ token (`str` or *bool*, *optional*):
708
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
709
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
710
+ revision (`str`, *optional*, defaults to `"main"`):
711
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
712
+ allowed by Git.
713
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
714
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
715
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
716
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
717
+ argument to `True` will raise an error.
718
+ """
719
+ # Load the main state dict first
720
+ cache_dir = kwargs.pop("cache_dir", None)
721
+ force_download = kwargs.pop("force_download", False)
722
+ proxies = kwargs.pop("proxies", None)
723
+ local_files_only = kwargs.pop("local_files_only", None)
724
+ token = kwargs.pop("token", None)
725
+ revision = kwargs.pop("revision", None)
726
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
727
+
728
+ if low_cpu_mem_usage and not is_accelerate_available():
729
+ low_cpu_mem_usage = False
730
+ logger.warning(
731
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
732
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
733
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
734
+ " install accelerate\n```\n."
735
+ )
736
+
737
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
738
+ raise NotImplementedError(
739
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
740
+ " `low_cpu_mem_usage=False`."
741
+ )
742
+
743
+ user_agent = {
744
+ "file_type": "attn_procs_weights",
745
+ "framework": "pytorch",
746
+ }
747
+
748
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
749
+ model_file = _get_model_file(
750
+ pretrained_model_name_or_path_or_dict,
751
+ weights_name=weight_name,
752
+ cache_dir=cache_dir,
753
+ force_download=force_download,
754
+ proxies=proxies,
755
+ local_files_only=local_files_only,
756
+ token=token,
757
+ revision=revision,
758
+ subfolder=subfolder,
759
+ user_agent=user_agent,
760
+ )
761
+ if weight_name.endswith(".safetensors"):
762
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
763
+ with safe_open(model_file, framework="pt", device="cpu") as f:
764
+ for key in f.keys():
765
+ if key.startswith("image_proj."):
766
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
767
+ elif key.startswith("ip_adapter."):
768
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
769
+ else:
770
+ state_dict = load_state_dict(model_file)
771
+ else:
772
+ state_dict = pretrained_model_name_or_path_or_dict
773
+
774
+ keys = list(state_dict.keys())
775
+ if "image_proj" not in keys and "ip_adapter" not in keys:
776
+ raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
777
+
778
+ # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
779
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
780
+ if image_encoder_folder is not None:
781
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
782
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
783
+ if image_encoder_folder.count("/") == 0:
784
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
785
+ else:
786
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
787
+
788
+ # Commons args for loading image encoder and image processor
789
+ kwargs = {
790
+ "low_cpu_mem_usage": low_cpu_mem_usage,
791
+ "cache_dir": cache_dir,
792
+ "local_files_only": local_files_only,
793
+ }
794
+
795
+ self.register_modules(
796
+ feature_extractor=SiglipImageProcessor.from_pretrained(
797
+ image_encoder_subfolder, **kwargs  # image processors are not torch modules, so no `.to()` here
798
+ ),
799
+ image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
800
+ self.device, dtype=self.dtype
801
+ ),
802
+ )
803
+ else:
804
+ raise ValueError(
805
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
806
+ )
807
+ else:
808
+ logger.warning(
809
+ "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
810
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
811
+ )
812
+
813
+ # Load IP-Adapter into transformer
814
+ self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
815
+
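A minimal sketch of the SD3 variant. The base model and adapter repo ids are illustrative assumptions; `weight_name` falls back to the default `ip-adapter.safetensors`.

```py
# Hypothetical usage sketch for SD3IPAdapterMixin.load_ip_adapter (checkpoint names are examples).
import torch
from diffusers import StableDiffusion3Pipeline

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")  # uses default weight_name="ip-adapter.safetensors"
pipeline.set_ip_adapter_scale(0.6)
```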
816
+ def set_ip_adapter_scale(self, scale: float) -> None:
817
+ """
818
+ Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
819
+ conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
820
+ the model to produce more diverse images, but they may not be as aligned with the image prompt.
821
+
822
+ Example:
823
+
824
+ ```python
825
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
826
+ >>> pipeline.set_ip_adapter_scale(0.6)
827
+ >>> ...
828
+ ```
829
+
830
+ Args:
831
+ scale (float):
832
+ IP-Adapter scale to be set.
833
+
834
+ """
835
+ for attn_processor in self.transformer.attn_processors.values():
836
+ if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
837
+ attn_processor.scale = scale
838
+
839
+ def unload_ip_adapter(self) -> None:
840
+ """
841
+ Unloads the IP Adapter weights.
842
+
843
+ Example:
844
+
845
+ ```python
846
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
847
+ >>> pipeline.unload_ip_adapter()
848
+ >>> ...
849
+ ```
850
+ """
851
+ # Remove image encoder
852
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
853
+ self.image_encoder = None
854
+ self.register_to_config(image_encoder=None)
855
+
856
+ # Remove feature extractor
857
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
858
+ self.feature_extractor = None
859
+ self.register_to_config(feature_extractor=None)
860
+
861
+ # Remove image projection
862
+ self.transformer.image_proj = None
863
+
864
+ # Restore original attention processors layers
865
+ attn_procs = {
866
+ name: (
867
+ JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
868
+ )
869
+ for name, value in self.transformer.attn_processors.items()
870
+ }
871
+ self.transformer.set_attn_processor(attn_procs)
icedit/diffusers/loaders/lora_base.py ADDED
@@ -0,0 +1,900 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Callable, Dict, List, Optional, Union
20
+
21
+ import safetensors
22
+ import torch
23
+ import torch.nn as nn
24
+ from huggingface_hub import model_info
25
+ from huggingface_hub.constants import HF_HUB_OFFLINE
26
+
27
+ from ..models.modeling_utils import ModelMixin, load_state_dict
28
+ from ..utils import (
29
+ USE_PEFT_BACKEND,
30
+ _get_model_file,
31
+ convert_state_dict_to_diffusers,
32
+ convert_state_dict_to_peft,
33
+ delete_adapter_layers,
34
+ deprecate,
35
+ get_adapter_name,
36
+ get_peft_kwargs,
37
+ is_accelerate_available,
38
+ is_peft_available,
39
+ is_peft_version,
40
+ is_transformers_available,
41
+ is_transformers_version,
42
+ logging,
43
+ recurse_remove_peft_layers,
44
+ scale_lora_layers,
45
+ set_adapter_layers,
46
+ set_weights_and_activate_adapters,
47
+ )
48
+
49
+
50
+ if is_transformers_available():
51
+ from transformers import PreTrainedModel
52
+
53
+ from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
54
+
55
+ if is_peft_available():
56
+ from peft.tuners.tuners_utils import BaseTunerLayer
57
+
58
+ if is_accelerate_available():
59
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
64
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
65
+
66
+
67
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
68
+ """
69
+ Fuses LoRAs for the text encoder.
70
+
71
+ Args:
72
+ text_encoder (`torch.nn.Module`):
73
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
74
+ attribute.
75
+ lora_scale (`float`, defaults to 1.0):
76
+ Controls how much to influence the outputs with the LoRA parameters.
77
+ safe_fusing (`bool`, defaults to `False`):
78
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
79
+ adapter_names (`List[str]` or `str`):
80
+ The names of the adapters to use.
81
+ """
82
+ merge_kwargs = {"safe_merge": safe_fusing}
83
+
84
+ for module in text_encoder.modules():
85
+ if isinstance(module, BaseTunerLayer):
86
+ if lora_scale != 1.0:
87
+ module.scale_layer(lora_scale)
88
+
89
+ # For BC with previous PEFT versions, we need to check the signature
90
+ # of the `merge` method to see if it supports the `adapter_names` argument.
91
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
92
+ if "adapter_names" in supported_merge_kwargs:
93
+ merge_kwargs["adapter_names"] = adapter_names
94
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
95
+ raise ValueError(
96
+ "The `adapter_names` argument is not supported with your PEFT version. "
97
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
98
+ )
99
+
100
+ module.merge(**merge_kwargs)
101
+
102
+
103
+ def unfuse_text_encoder_lora(text_encoder):
104
+ """
105
+ Unfuses LoRAs for the text encoder.
106
+
107
+ Args:
108
+ text_encoder (`torch.nn.Module`):
109
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
110
+ attribute.
111
+ """
112
+ for module in text_encoder.modules():
113
+ if isinstance(module, BaseTunerLayer):
114
+ module.unmerge()
115
+
116
+
117
+ def set_adapters_for_text_encoder(
118
+ adapter_names: Union[List[str], str],
119
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
120
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
121
+ ):
122
+ """
123
+ Sets the adapter layers for the text encoder.
124
+
125
+ Args:
126
+ adapter_names (`List[str]` or `str`):
127
+ The names of the adapters to use.
128
+ text_encoder (`torch.nn.Module`, *optional*):
129
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
130
+ attribute.
131
+ text_encoder_weights (`List[float]`, *optional*):
132
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
133
+ """
134
+ if text_encoder is None:
135
+ raise ValueError(
136
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
137
+ )
138
+
139
+ def process_weights(adapter_names, weights):
140
+ # Expand weights into a list, one entry per adapter
141
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
142
+ if not isinstance(weights, list):
143
+ weights = [weights] * len(adapter_names)
144
+
145
+ if len(adapter_names) != len(weights):
146
+ raise ValueError(
147
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
148
+ )
149
+
150
+ # Set None values to default of 1.0
151
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
152
+ weights = [w if w is not None else 1.0 for w in weights]
153
+
154
+ return weights
155
+
156
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
157
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
158
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
159
+
160
+
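To make the weight expansion above concrete, a standalone sketch that mirrors the nested `process_weights` helper. The `expand_weights` name and the adapter names are hypothetical, for illustration only.

```py
# Hypothetical standalone mirror of the `process_weights` logic above.
def expand_weights(adapter_names, weights):
    # Broadcast a scalar to one entry per adapter, then map None -> 1.0.
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)
    if len(adapter_names) != len(weights):
        raise ValueError("adapter_names and weights must have the same length")
    return [w if w is not None else 1.0 for w in weights]

print(expand_weights(["pixel", "toy"], 0.7))          # [0.7, 0.7]
print(expand_weights(["pixel", "toy"], [0.5, None]))  # [0.5, 1.0]
```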
161
+ def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
162
+ """
163
+ Disables the LoRA layers for the text encoder.
164
+
165
+ Args:
166
+ text_encoder (`torch.nn.Module`, *optional*):
167
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
168
+ attribute.
169
+ """
170
+ if text_encoder is None:
171
+ raise ValueError("Text Encoder not found.")
172
+ set_adapter_layers(text_encoder, enabled=False)
173
+
174
+
175
+ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
176
+ """
177
+ Enables the LoRA layers for the text encoder.
178
+
179
+ Args:
180
+ text_encoder (`torch.nn.Module`, *optional*):
181
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
182
+ attribute.
183
+ """
184
+ if text_encoder is None:
185
+ raise ValueError("Text Encoder not found.")
186
+ set_adapter_layers(text_encoder, enabled=True)
187
+
188
+
189
+ def _remove_text_encoder_monkey_patch(text_encoder):
190
+ recurse_remove_peft_layers(text_encoder)
191
+ if getattr(text_encoder, "peft_config", None) is not None:
192
+ del text_encoder.peft_config
193
+ text_encoder._hf_peft_config_loaded = None
194
+
195
+
196
+ def _fetch_state_dict(
197
+ pretrained_model_name_or_path_or_dict,
198
+ weight_name,
199
+ use_safetensors,
200
+ local_files_only,
201
+ cache_dir,
202
+ force_download,
203
+ proxies,
204
+ token,
205
+ revision,
206
+ subfolder,
207
+ user_agent,
208
+ allow_pickle,
209
+ ):
210
+ model_file = None
211
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
212
+ # Let's first try to load .safetensors weights
213
+ if (use_safetensors and weight_name is None) or (
214
+ weight_name is not None and weight_name.endswith(".safetensors")
215
+ ):
216
+ try:
217
+ # Here we're relaxing the loading check to enable more Inference API
218
+ # friendliness where sometimes, it's not at all possible to automatically
219
+ # determine `weight_name`.
220
+ if weight_name is None:
221
+ weight_name = _best_guess_weight_name(
222
+ pretrained_model_name_or_path_or_dict,
223
+ file_extension=".safetensors",
224
+ local_files_only=local_files_only,
225
+ )
226
+ model_file = _get_model_file(
227
+ pretrained_model_name_or_path_or_dict,
228
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
229
+ cache_dir=cache_dir,
230
+ force_download=force_download,
231
+ proxies=proxies,
232
+ local_files_only=local_files_only,
233
+ token=token,
234
+ revision=revision,
235
+ subfolder=subfolder,
236
+ user_agent=user_agent,
237
+ )
238
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
239
+ except (IOError, safetensors.SafetensorError) as e:
240
+ if not allow_pickle:
241
+ raise e
242
+ # try loading non-safetensors weights
243
+ model_file = None
244
+ pass
245
+
246
+ if model_file is None:
247
+ if weight_name is None:
248
+ weight_name = _best_guess_weight_name(
249
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
250
+ )
251
+ model_file = _get_model_file(
252
+ pretrained_model_name_or_path_or_dict,
253
+ weights_name=weight_name or LORA_WEIGHT_NAME,
254
+ cache_dir=cache_dir,
255
+ force_download=force_download,
256
+ proxies=proxies,
257
+ local_files_only=local_files_only,
258
+ token=token,
259
+ revision=revision,
260
+ subfolder=subfolder,
261
+ user_agent=user_agent,
262
+ )
263
+ state_dict = load_state_dict(model_file)
264
+ else:
265
+ state_dict = pretrained_model_name_or_path_or_dict
266
+
267
+ return state_dict
268
+
269
+
270
+ def _best_guess_weight_name(
271
+ pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
272
+ ):
273
+ if local_files_only or HF_HUB_OFFLINE:
274
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
275
+
276
+ targeted_files = []
277
+
278
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
279
+ return
280
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
281
+ targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
282
+ else:
283
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
284
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
285
+ if len(targeted_files) == 0:
286
+ return
287
+
288
+ # "scheduler" does not correspond to a LoRA checkpoint.
289
+ # "optimizer" does not correspond to a LoRA checkpoint
290
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
291
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
292
+ targeted_files = list(
293
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
294
+ )
295
+
296
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
297
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
298
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
299
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
300
+
301
+ if len(targeted_files) > 1:
302
+ raise ValueError(
303
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
304
+ )
305
+ weight_name = targeted_files[0]
306
+ return weight_name
307
+
308
+
309
+ def _load_lora_into_text_encoder(
310
+ state_dict,
311
+ network_alphas,
312
+ text_encoder,
313
+ prefix=None,
314
+ lora_scale=1.0,
315
+ text_encoder_name="text_encoder",
316
+ adapter_name=None,
317
+ _pipeline=None,
318
+ low_cpu_mem_usage=False,
319
+ ):
320
+ if not USE_PEFT_BACKEND:
321
+ raise ValueError("PEFT backend is required for this method.")
322
+
323
+ peft_kwargs = {}
324
+ if low_cpu_mem_usage:
325
+ if not is_peft_version(">=", "0.13.1"):
326
+ raise ValueError(
327
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
328
+ )
329
+ if not is_transformers_version(">", "4.45.2"):
330
+ # Note from sayakpaul: It's not in `transformers` stable yet.
331
+ # https://github.com/huggingface/transformers/pull/33725/
332
+ raise ValueError(
333
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
334
+ )
335
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
336
+
337
+ from peft import LoraConfig
338
+
339
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
340
+ # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
341
+ # their prefixes.
342
+ keys = list(state_dict.keys())
343
+ prefix = text_encoder_name if prefix is None else prefix
344
+
345
+ # Safe prefix to check with.
346
+ if any(text_encoder_name in key for key in keys):
347
+ # Load the layers corresponding to text encoder and make necessary adjustments.
348
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
349
+ text_encoder_lora_state_dict = {
350
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
351
+ }
352
+
353
+ if len(text_encoder_lora_state_dict) > 0:
354
+ logger.info(f"Loading {prefix}.")
355
+ rank = {}
356
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
357
+
358
+ # convert state dict
359
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
360
+
361
+ for name, _ in text_encoder_attn_modules(text_encoder):
362
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
363
+ rank_key = f"{name}.{module}.lora_B.weight"
364
+ if rank_key not in text_encoder_lora_state_dict:
365
+ continue
366
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
367
+
368
+ for name, _ in text_encoder_mlp_modules(text_encoder):
369
+ for module in ("fc1", "fc2"):
370
+ rank_key = f"{name}.{module}.lora_B.weight"
371
+ if rank_key not in text_encoder_lora_state_dict:
372
+ continue
373
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
374
+
375
+ if network_alphas is not None:
376
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
377
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
378
+
379
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
380
+
381
+ if "use_dora" in lora_config_kwargs:
382
+ if lora_config_kwargs["use_dora"]:
383
+ if is_peft_version("<", "0.9.0"):
384
+ raise ValueError(
385
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
386
+ )
387
+ else:
388
+ if is_peft_version("<", "0.9.0"):
389
+ lora_config_kwargs.pop("use_dora")
390
+
391
+ if "lora_bias" in lora_config_kwargs:
392
+ if lora_config_kwargs["lora_bias"]:
393
+ if is_peft_version("<=", "0.13.2"):
394
+ raise ValueError(
395
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
396
+ )
397
+ else:
398
+ if is_peft_version("<=", "0.13.2"):
399
+ lora_config_kwargs.pop("lora_bias")
400
+
401
+ lora_config = LoraConfig(**lora_config_kwargs)
402
+
403
+ # adapter_name
404
+ if adapter_name is None:
405
+ adapter_name = get_adapter_name(text_encoder)
406
+
407
+ is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
408
+
409
+ # inject LoRA layers and load the state dict
410
+ # in transformers we automatically check whether the adapter name is already in use or not
411
+ text_encoder.load_adapter(
412
+ adapter_name=adapter_name,
413
+ adapter_state_dict=text_encoder_lora_state_dict,
414
+ peft_config=lora_config,
415
+ **peft_kwargs,
416
+ )
417
+
418
+ # scale LoRA layers with `lora_scale`
419
+ scale_lora_layers(text_encoder, weight=lora_scale)
420
+
421
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
422
+
423
+ # Offload back.
424
+ if is_model_cpu_offload:
425
+ _pipeline.enable_model_cpu_offload()
426
+ elif is_sequential_cpu_offload:
427
+ _pipeline.enable_sequential_cpu_offload()
428
+ # Unsafe code />
429
+
430
+
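To make the prefix handling in `_load_lora_into_text_encoder` concrete, a small sketch of the key filtering it performs. The state-dict key names are hypothetical placeholders.

```py
# Hypothetical illustration of the `text_encoder` prefix filtering done above.
state_dict = {
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "tensor-A",
    "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": "tensor-B",
}
prefix = "text_encoder"
kept = {
    k.replace(f"{prefix}.", ""): v
    for k, v in state_dict.items()
    if k.startswith(prefix) and k.split(".")[0] == prefix
}
print(kept)  # only the text_encoder.* entry survives, with the prefix stripped
```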
431
+ def _func_optionally_disable_offloading(_pipeline):
432
+ is_model_cpu_offload = False
433
+ is_sequential_cpu_offload = False
434
+
435
+ if _pipeline is not None and _pipeline.hf_device_map is None:
436
+ for _, component in _pipeline.components.items():
437
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
438
+ if not is_model_cpu_offload:
439
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
440
+ if not is_sequential_cpu_offload:
441
+ is_sequential_cpu_offload = (
442
+ isinstance(component._hf_hook, AlignDevicesHook)
443
+ or hasattr(component._hf_hook, "hooks")
444
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
445
+ )
446
+
447
+ logger.info(
448
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
449
+ )
450
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
451
+
452
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
453
+
454
+
455
+ class LoraBaseMixin:
456
+ """Utility class for handling LoRAs."""
457
+
458
+ _lora_loadable_modules = []
459
+ num_fused_loras = 0
460
+
461
+ def load_lora_weights(self, **kwargs):
462
+ raise NotImplementedError("`load_lora_weights()` is not implemented.")
463
+
464
+ @classmethod
465
+ def save_lora_weights(cls, **kwargs):
466
+ raise NotImplementedError("`save_lora_weights()` not implemented.")
467
+
468
+ @classmethod
469
+ def lora_state_dict(cls, **kwargs):
470
+ raise NotImplementedError("`lora_state_dict()` is not implemented.")
471
+
472
+ @classmethod
473
+ def _optionally_disable_offloading(cls, _pipeline):
474
+ """
475
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
476
+
477
+ Args:
478
+ _pipeline (`DiffusionPipeline`):
479
+ The pipeline to disable offloading for.
480
+
481
+ Returns:
482
+ tuple:
483
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
484
+ """
485
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
486
+
487
+ @classmethod
488
+ def _fetch_state_dict(cls, *args, **kwargs):
489
+ deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
490
+ deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
491
+ return _fetch_state_dict(*args, **kwargs)
492
+
493
+ @classmethod
494
+ def _best_guess_weight_name(cls, *args, **kwargs):
495
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
496
+ deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
497
+ return _best_guess_weight_name(*args, **kwargs)
498
+
499
+ def unload_lora_weights(self):
500
+ """
501
+ Unloads the LoRA parameters.
502
+
503
+ Examples:
504
+
505
+ ```python
506
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
507
+ >>> pipeline.unload_lora_weights()
508
+ >>> ...
509
+ ```
510
+ """
511
+ if not USE_PEFT_BACKEND:
512
+ raise ValueError("PEFT backend is required for this method.")
513
+
514
+ for component in self._lora_loadable_modules:
515
+ model = getattr(self, component, None)
516
+ if model is not None:
517
+ if issubclass(model.__class__, ModelMixin):
518
+ model.unload_lora()
519
+ elif issubclass(model.__class__, PreTrainedModel):
520
+ _remove_text_encoder_monkey_patch(model)
521
+
522
+ def fuse_lora(
523
+ self,
524
+ components: List[str] = [],
525
+ lora_scale: float = 1.0,
526
+ safe_fusing: bool = False,
527
+ adapter_names: Optional[List[str]] = None,
528
+ **kwargs,
529
+ ):
530
+ r"""
531
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
532
+
533
+ <Tip warning={true}>
534
+
535
+ This is an experimental API.
536
+
537
+ </Tip>
538
+
539
+ Args:
540
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
541
+ lora_scale (`float`, defaults to 1.0):
542
+ Controls how much to influence the outputs with the LoRA parameters.
543
+ safe_fusing (`bool`, defaults to `False`):
544
+ Whether to check the fused weights for NaN values before fusing, and to skip fusing weights whose values are NaN.
545
+ adapter_names (`List[str]`, *optional*):
546
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
547
+
548
+ Example:
549
+
550
+ ```py
551
+ from diffusers import DiffusionPipeline
552
+ import torch
553
+
554
+ pipeline = DiffusionPipeline.from_pretrained(
555
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
556
+ ).to("cuda")
557
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
558
+ pipeline.fuse_lora(lora_scale=0.7)
559
+ ```
560
+ """
561
+ if "fuse_unet" in kwargs:
562
+ depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
563
+ deprecate(
564
+ "fuse_unet",
565
+ "1.0.0",
566
+ depr_message,
567
+ )
568
+ if "fuse_transformer" in kwargs:
569
+ depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
570
+ deprecate(
571
+ "fuse_transformer",
572
+ "1.0.0",
573
+ depr_message,
574
+ )
575
+ if "fuse_text_encoder" in kwargs:
576
+ depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
577
+ deprecate(
578
+ "fuse_text_encoder",
579
+ "1.0.0",
580
+ depr_message,
581
+ )
582
+
583
+ if len(components) == 0:
584
+ raise ValueError("`components` cannot be an empty list.")
585
+
586
+ for fuse_component in components:
587
+ if fuse_component not in self._lora_loadable_modules:
588
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
589
+
590
+ model = getattr(self, fuse_component, None)
591
+ if model is not None:
592
+ # check if diffusers model
593
+ if issubclass(model.__class__, ModelMixin):
594
+ model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
595
+ # handle transformers models.
596
+ if issubclass(model.__class__, PreTrainedModel):
597
+ fuse_text_encoder_lora(
598
+ model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
599
+ )
600
+
601
+ self.num_fused_loras += 1
602
+
603
+ def unfuse_lora(self, components: List[str] = [], **kwargs):
604
+ r"""
605
+ Reverses the effect of
606
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
607
+
608
+ <Tip warning={true}>
609
+
610
+ This is an experimental API.
611
+
612
+ </Tip>
613
+
614
+ Args:
615
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
616
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
617
+ unfuse_text_encoder (`bool`, defaults to `True`):
618
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
619
+ LoRA parameters then it won't have any effect.
620
+ """
621
+ if "unfuse_unet" in kwargs:
622
+ depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
623
+ deprecate(
624
+ "unfuse_unet",
625
+ "1.0.0",
626
+ depr_message,
627
+ )
628
+ if "unfuse_transformer" in kwargs:
629
+ depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
630
+ deprecate(
631
+ "unfuse_transformer",
632
+ "1.0.0",
633
+ depr_message,
634
+ )
635
+ if "unfuse_text_encoder" in kwargs:
636
+ depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
637
+ deprecate(
638
+ "unfuse_text_encoder",
639
+ "1.0.0",
640
+ depr_message,
641
+ )
642
+
643
+ if len(components) == 0:
644
+ raise ValueError("`components` cannot be an empty list.")
645
+
646
+ for fuse_component in components:
647
+ if fuse_component not in self._lora_loadable_modules:
648
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
649
+
650
+ model = getattr(self, fuse_component, None)
651
+ if model is not None:
652
+ if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
653
+ for module in model.modules():
654
+ if isinstance(module, BaseTunerLayer):
655
+ module.unmerge()
656
+
657
+ self.num_fused_loras -= 1
658
+
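To complement the example in the docstring above, a short sketch of a full fuse/unfuse round trip. This assumes an SDXL pipeline whose LoRA-loadable components include `"unet"` and `"text_encoder"`, and reuses the repo and weight file from the `fuse_lora` example:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Merge the LoRA deltas into the base weights for faster inference ...
pipeline.fuse_lora(components=["unet", "text_encoder"], lora_scale=0.7)

# ... and restore the original, unfused weights afterwards.
pipeline.unfuse_lora(components=["unet", "text_encoder"])
```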
659
+ def set_adapters(
660
+ self,
661
+ adapter_names: Union[List[str], str],
662
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
663
+ ):
664
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
665
+
666
+ adapter_weights = copy.deepcopy(adapter_weights)
667
+
668
+ # Expand weights into a list, one entry per adapter
669
+ if not isinstance(adapter_weights, list):
670
+ adapter_weights = [adapter_weights] * len(adapter_names)
671
+
672
+ if len(adapter_names) != len(adapter_weights):
673
+ raise ValueError(
674
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
675
+ )
676
+
677
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
678
+ # eg ["adapter1", "adapter2"]
679
+ all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
680
+ missing_adapters = set(adapter_names) - all_adapters
681
+ if len(missing_adapters) > 0:
682
+ raise ValueError(
683
+ f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
684
+ )
685
+
686
+ # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
687
+ invert_list_adapters = {
688
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
689
+ for adapter in all_adapters
690
+ }
691
+
692
+ # Decompose weights into weights for denoiser and text encoders.
693
+ _component_adapter_weights = {}
694
+ for component in self._lora_loadable_modules:
695
+ model = getattr(self, component)
696
+
697
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
698
+ if isinstance(weights, dict):
699
+ component_adapter_weights = weights.pop(component, None)
700
+
701
+ if component_adapter_weights is not None and not hasattr(self, component):
702
+ logger.warning(
703
+ f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
704
+ )
705
+
706
+ if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
707
+ logger.warning(
708
+ (
709
+ f"Lora weight dict for adapter '{adapter_name}' contains {component}, "
710
+ f"but this will be ignored because {adapter_name} does not contain weights for {component}. "
711
+ f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
712
+ )
713
+ )
714
+
715
+ else:
716
+ component_adapter_weights = weights
717
+
718
+ _component_adapter_weights.setdefault(component, [])
719
+ _component_adapter_weights[component].append(component_adapter_weights)
720
+
721
+ if issubclass(model.__class__, ModelMixin):
722
+ model.set_adapters(adapter_names, _component_adapter_weights[component])
723
+ elif issubclass(model.__class__, PreTrainedModel):
724
+ set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
725
+
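A brief sketch of typical `set_adapters` usage, reusing the repos from the other examples in this file. A scalar weight applies to every component an adapter was loaded into; a dict scales individual components, and components not listed fall back to the default scale:

```python
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")

# Blend both adapters: "pixel" at full strength everywhere, "toy" at 0.6 only on the unet.
pipeline.set_adapters(["pixel", "toy"], adapter_weights=[1.0, {"unet": 0.6}])
```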
726
+ def disable_lora(self):
727
+ if not USE_PEFT_BACKEND:
728
+ raise ValueError("PEFT backend is required for this method.")
729
+
730
+ for component in self._lora_loadable_modules:
731
+ model = getattr(self, component, None)
732
+ if model is not None:
733
+ if issubclass(model.__class__, ModelMixin):
734
+ model.disable_lora()
735
+ elif issubclass(model.__class__, PreTrainedModel):
736
+ disable_lora_for_text_encoder(model)
737
+
738
+ def enable_lora(self):
739
+ if not USE_PEFT_BACKEND:
740
+ raise ValueError("PEFT backend is required for this method.")
741
+
742
+ for component in self._lora_loadable_modules:
743
+ model = getattr(self, component, None)
744
+ if model is not None:
745
+ if issubclass(model.__class__, ModelMixin):
746
+ model.enable_lora()
747
+ elif issubclass(model.__class__, PreTrainedModel):
748
+ enable_lora_for_text_encoder(model)
749
+
750
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
751
+ """
752
+ Deletes the LoRA layers of `adapter_names` for the unet and text-encoder(s).
753
+ Args:
754
+ adapter_names (`Union[List[str], str]`):
755
+ The names of the adapters to delete. Can be a single string or a list of strings.
756
+ """
757
+ if not USE_PEFT_BACKEND:
758
+ raise ValueError("PEFT backend is required for this method.")
759
+
760
+ if isinstance(adapter_names, str):
761
+ adapter_names = [adapter_names]
762
+
763
+ for component in self._lora_loadable_modules:
764
+ model = getattr(self, component, None)
765
+ if model is not None:
766
+ if issubclass(model.__class__, ModelMixin):
767
+ model.delete_adapters(adapter_names)
768
+ elif issubclass(model.__class__, PreTrainedModel):
769
+ for adapter_name in adapter_names:
770
+ delete_adapter_layers(model, adapter_name)
771
+
772
+ def get_active_adapters(self) -> List[str]:
773
+ """
774
+ Gets the list of the current active adapters.
775
+
776
+ Example:
777
+
778
+ ```python
779
+ from diffusers import DiffusionPipeline
780
+
781
+ pipeline = DiffusionPipeline.from_pretrained(
782
+ "stabilityai/stable-diffusion-xl-base-1.0",
783
+ ).to("cuda")
784
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
785
+ pipeline.get_active_adapters()
786
+ ```
787
+ """
788
+ if not USE_PEFT_BACKEND:
789
+ raise ValueError(
790
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
791
+ )
792
+
793
+ active_adapters = []
794
+
795
+ for component in self._lora_loadable_modules:
796
+ model = getattr(self, component, None)
797
+ if model is not None and issubclass(model.__class__, ModelMixin):
798
+ for module in model.modules():
799
+ if isinstance(module, BaseTunerLayer):
800
+ active_adapters = module.active_adapters
801
+ break
802
+
803
+ return active_adapters
804
+
805
+ def get_list_adapters(self) -> Dict[str, List[str]]:
806
+ """
807
+ Gets the current list of all available adapters in the pipeline.
808
+ """
809
+ if not USE_PEFT_BACKEND:
810
+ raise ValueError(
811
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
812
+ )
813
+
814
+ set_adapters = {}
815
+
816
+ for component in self._lora_loadable_modules:
817
+ model = getattr(self, component, None)
818
+ if (
819
+ model is not None
820
+ and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
821
+ and hasattr(model, "peft_config")
822
+ ):
823
+ set_adapters[component] = list(model.peft_config.keys())
824
+
825
+ return set_adapters
826
+
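For illustration, the mapping returned here is keyed by component name; the exact contents depend on what has been loaded (the adapter names below are hypothetical):

```python
adapters_by_component = pipeline.get_list_adapters()
# e.g. {"unet": ["pixel", "toy"], "text_encoder": ["pixel"]}
for component, adapter_names in adapters_by_component.items():
    print(component, adapter_names)
```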
827
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
828
+ """
829
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
830
+ you want to load multiple adapters and free some GPU memory.
831
+
832
+ Args:
833
+ adapter_names (`List[str]`):
834
+ List of adapters whose LoRA weights should be moved to `device`.
835
+ device (`Union[torch.device, str, int]`):
836
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
837
+ """
838
+ if not USE_PEFT_BACKEND:
839
+ raise ValueError("PEFT backend is required for this method.")
840
+
841
+ for component in self._lora_loadable_modules:
842
+ model = getattr(self, component, None)
843
+ if model is not None:
844
+ for module in model.modules():
845
+ if isinstance(module, BaseTunerLayer):
846
+ for adapter_name in adapter_names:
847
+ module.lora_A[adapter_name].to(device)
848
+ module.lora_B[adapter_name].to(device)
849
+ # this is a param, not a module, so device placement is not in-place -> re-assign
850
+ if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
851
+ if adapter_name in module.lora_magnitude_vector:
852
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
853
+ adapter_name
854
+ ].to(device)
855
+
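A short sketch of using `set_lora_device` to park one adapter's LoRA matrices on the CPU and bring them back later (the adapter name is reused from the hypothetical snippets above):

```python
# Free some GPU memory by moving the "toy" adapter's LoRA weights to the CPU ...
pipeline.set_lora_device(adapter_names=["toy"], device="cpu")

# ... and move them back before activating that adapter again.
pipeline.set_lora_device(adapter_names=["toy"], device="cuda:0")
pipeline.set_adapters(["toy"])
```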
856
+ @staticmethod
857
+ def pack_weights(layers, prefix):
858
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
859
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
860
+ return layers_state_dict
861
+
862
+ @staticmethod
863
+ def write_lora_layers(
864
+ state_dict: Dict[str, torch.Tensor],
865
+ save_directory: str,
866
+ is_main_process: bool,
867
+ weight_name: str,
868
+ save_function: Callable,
869
+ safe_serialization: bool,
870
+ ):
871
+ if os.path.isfile(save_directory):
872
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
873
+ return
874
+
875
+ if save_function is None:
876
+ if safe_serialization:
877
+
878
+ def save_function(weights, filename):
879
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
880
+
881
+ else:
882
+ save_function = torch.save
883
+
884
+ os.makedirs(save_directory, exist_ok=True)
885
+
886
+ if weight_name is None:
887
+ if safe_serialization:
888
+ weight_name = LORA_WEIGHT_NAME_SAFE
889
+ else:
890
+ weight_name = LORA_WEIGHT_NAME
891
+
892
+ save_path = Path(save_directory, weight_name).as_posix()
893
+ save_function(state_dict, save_path)
894
+ logger.info(f"Model weights saved in {save_path}")
895
+
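A minimal sketch of how these two static helpers fit together inside a subclass's `save_lora_weights()` implementation. Here `unet_lora_layers` stands in for a dict of LoRA tensors (or an `nn.Module` holding them) and the directory name is arbitrary:

```python
state_dict = LoraBaseMixin.pack_weights(unet_lora_layers, prefix="unet")

LoraBaseMixin.write_lora_layers(
    state_dict=state_dict,
    save_directory="./my_lora",
    is_main_process=True,
    weight_name=None,        # falls back to the default safetensors file name
    save_function=None,      # picked automatically based on `safe_serialization`
    safe_serialization=True,
)
```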
896
+ @property
897
+ def lora_scale(self) -> float:
898
+ # property function that returns the lora scale which can be set at run time by the pipeline.
899
+ # if _lora_scale has not been set, return 1
900
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
icedit/diffusers/loaders/lora_conversion_utils.py ADDED
@@ -0,0 +1,1150 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ import torch
18
+
19
+ from ..utils import is_peft_version, logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
26
+ # 1. get all state_dict_keys
27
+ all_keys = list(state_dict.keys())
28
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
29
+
30
+ # 2. check if needs remapping, if not return original dict
31
+ is_in_sgm_format = False
32
+ for key in all_keys:
33
+ if any(p in key for p in sgm_patterns):
34
+ is_in_sgm_format = True
35
+ break
36
+
37
+ if not is_in_sgm_format:
38
+ return state_dict
39
+
40
+ # 3. Else remap from SGM patterns
41
+ new_state_dict = {}
42
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
43
+
44
+ # Retrieves # of down, mid and up blocks
45
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
46
+
47
+ for layer in all_keys:
48
+ if "text" in layer:
49
+ new_state_dict[layer] = state_dict.pop(layer)
50
+ else:
51
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
52
+ if sgm_patterns[0] in layer:
53
+ input_block_ids.add(layer_id)
54
+ elif sgm_patterns[1] in layer:
55
+ middle_block_ids.add(layer_id)
56
+ elif sgm_patterns[2] in layer:
57
+ output_block_ids.add(layer_id)
58
+ else:
59
+ raise ValueError(f"Checkpoint not supported because layer {layer} is not supported.")
60
+
61
+ input_blocks = {
62
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
63
+ for layer_id in input_block_ids
64
+ }
65
+ middle_blocks = {
66
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
67
+ for layer_id in middle_block_ids
68
+ }
69
+ output_blocks = {
70
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
71
+ for layer_id in output_block_ids
72
+ }
73
+
74
+ # Rename keys accordingly
75
+ for i in input_block_ids:
76
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
77
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
78
+
79
+ for key in input_blocks[i]:
80
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
81
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
82
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
83
+ new_key = delimiter.join(
84
+ key.split(delimiter)[: block_slice_pos - 1]
85
+ + [str(block_id), inner_block_key, inner_layers_in_block]
86
+ + key.split(delimiter)[block_slice_pos + 1 :]
87
+ )
88
+ new_state_dict[new_key] = state_dict.pop(key)
89
+
90
+ for i in middle_block_ids:
91
+ key_part = None
92
+ if i == 0:
93
+ key_part = [inner_block_map[0], "0"]
94
+ elif i == 1:
95
+ key_part = [inner_block_map[1], "0"]
96
+ elif i == 2:
97
+ key_part = [inner_block_map[0], "1"]
98
+ else:
99
+ raise ValueError(f"Invalid middle block id {i}.")
100
+
101
+ for key in middle_blocks[i]:
102
+ new_key = delimiter.join(
103
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
104
+ )
105
+ new_state_dict[new_key] = state_dict.pop(key)
106
+
107
+ for i in output_block_ids:
108
+ block_id = i // (unet_config.layers_per_block + 1)
109
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
110
+
111
+ for key in output_blocks[i]:
112
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
113
+ inner_block_key = inner_block_map[inner_block_id]
114
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
115
+ new_key = delimiter.join(
116
+ key.split(delimiter)[: block_slice_pos - 1]
117
+ + [str(block_id), inner_block_key, inner_layers_in_block]
118
+ + key.split(delimiter)[block_slice_pos + 1 :]
119
+ )
120
+ new_state_dict[new_key] = state_dict.pop(key)
121
+
122
+ if len(state_dict) > 0:
123
+ raise ValueError("At this point all state dict entries have to be converted.")
124
+
125
+ return new_state_dict
126
+
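To make the index arithmetic above concrete, a small worked example assuming `unet_config.layers_per_block == 2` (the SDXL value): SGM `input_blocks.4` ends up under diffusers `down_blocks.1` with an in-block index of 0.

```python
layers_per_block = 2                                  # assumed SDXL-style config
i = 4                                                 # SGM input_blocks index
block_id = (i - 1) // (layers_per_block + 1)          # -> 1
layer_in_block_id = (i - 1) % (layers_per_block + 1)  # -> 0
# i.e. "input_blocks_4_..." keys are rewritten to "...down_blocks.1.<resnets|attentions>.0..."
```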
127
+
128
+ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
129
+ """
130
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
131
+
132
+ Args:
133
+ state_dict (`dict`): The state dict to convert.
134
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
135
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
136
+ "text_encoder".
137
+
138
+ Returns:
139
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
140
+ """
141
+ unet_state_dict = {}
142
+ te_state_dict = {}
143
+ te2_state_dict = {}
144
+ network_alphas = {}
145
+
146
+ # Check for DoRA-enabled LoRAs.
147
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
148
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
149
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
150
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
151
+ if is_peft_version("<", "0.9.0"):
152
+ raise ValueError(
153
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
154
+ )
155
+
156
+ # Iterate over all LoRA weights.
157
+ all_lora_keys = list(state_dict.keys())
158
+ for key in all_lora_keys:
159
+ if not key.endswith("lora_down.weight"):
160
+ continue
161
+
162
+ # Extract LoRA name.
163
+ lora_name = key.split(".")[0]
164
+
165
+ # Find corresponding up weight and alpha.
166
+ lora_name_up = lora_name + ".lora_up.weight"
167
+ lora_name_alpha = lora_name + ".alpha"
168
+
169
+ # Handle U-Net LoRAs.
170
+ if lora_name.startswith("lora_unet_"):
171
+ diffusers_name = _convert_unet_lora_key(key)
172
+
173
+ # Store down and up weights.
174
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
175
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
176
+
177
+ # Store DoRA scale if present.
178
+ if dora_present_in_unet:
179
+ dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
180
+ unet_state_dict[
181
+ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
182
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
183
+
184
+ # Handle text encoder LoRAs.
185
+ elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
186
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
187
+
188
+ # Store down and up weights for te or te2.
189
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
190
+ te_state_dict[diffusers_name] = state_dict.pop(key)
191
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
192
+ else:
193
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
194
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
195
+
196
+ # Store DoRA scale if present.
197
+ if dora_present_in_te or dora_present_in_te2:
198
+ dora_scale_key_to_replace_te = (
199
+ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
200
+ )
201
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
202
+ te_state_dict[
203
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
204
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
205
+ elif lora_name.startswith("lora_te2_"):
206
+ te2_state_dict[
207
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
208
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
209
+
210
+ # Store alpha if present.
211
+ if lora_name_alpha in state_dict:
212
+ alpha = state_dict.pop(lora_name_alpha).item()
213
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
214
+
215
+ # Check if any keys remain.
216
+ if len(state_dict) > 0:
217
+ raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
218
+
219
+ logger.info("Non-diffusers checkpoint detected.")
220
+
221
+ # Construct final state dict.
222
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
223
+ te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
224
+ te2_state_dict = (
225
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
226
+ if len(te2_state_dict) > 0
227
+ else None
228
+ )
229
+ if te2_state_dict is not None:
230
+ te_state_dict.update(te2_state_dict)
231
+
232
+ new_state_dict = {**unet_state_dict, **te_state_dict}
233
+ return new_state_dict, network_alphas
234
+
235
+
236
+ def _convert_unet_lora_key(key):
237
+ """
238
+ Converts a U-Net LoRA key to a Diffusers compatible key.
239
+ """
240
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
241
+
242
+ # Replace common U-Net naming patterns.
243
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
244
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
245
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
246
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
247
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
248
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
249
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
250
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
251
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
252
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
253
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
254
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
255
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
256
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
257
+
258
+ # SDXL specific conversions.
259
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
260
+ pattern = r"\.\d+(?=\D*$)"
261
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
262
+ if ".in." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
264
+ if ".out." in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
266
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("op", "conv")
268
+ if "skip" in diffusers_name:
269
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
270
+
271
+ # LyCORIS specific conversions.
272
+ if "time.emb.proj" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
274
+ if "conv.shortcut" in diffusers_name:
275
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
276
+
277
+ # General conversions.
278
+ if "transformer_blocks" in diffusers_name:
279
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
280
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
281
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
282
+ elif "ff" in diffusers_name:
283
+ pass
284
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
285
+ pass
286
+ else:
287
+ pass
288
+
289
+ return diffusers_name
290
+
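As a quick illustration, traced by hand through the substitutions listed above (so treat the exact output as indicative rather than exhaustive):

```python
key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
_convert_unet_lora_key(key)
# -> "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
#    (later prefixed with "unet." by _convert_non_diffusers_lora_to_diffusers)
```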
291
+
292
+ def _convert_text_encoder_lora_key(key, lora_name):
293
+ """
294
+ Converts a text encoder LoRA key to a Diffusers compatible key.
295
+ """
296
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
297
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
298
+ else:
299
+ key_to_replace = "lora_te2_"
300
+
301
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
302
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
303
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
304
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
305
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
306
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
307
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
308
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
309
+
310
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
311
+ pass
312
+ elif "mlp" in diffusers_name:
313
+ # Be aware that this is the new diffusers convention and the rest of the code might
314
+ # not utilize it yet.
315
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
316
+ return diffusers_name
317
+
318
+
319
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
320
+ """
321
+ Gets the correct alpha name for the Diffusers model.
322
+ """
323
+ if lora_name_alpha.startswith("lora_unet_"):
324
+ prefix = "unet."
325
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
326
+ prefix = "text_encoder."
327
+ else:
328
+ prefix = "text_encoder_2."
329
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
330
+ return {new_name: alpha}
331
+
332
+
333
+ # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
334
+ # are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
335
+ # All credits go to `kohya-ss`.
336
+ def _convert_kohya_flux_lora_to_diffusers(state_dict):
337
+ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
338
+ if sds_key + ".lora_down.weight" not in sds_sd:
339
+ return
340
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
341
+
342
+ # scale weight by alpha and dim
343
+ rank = down_weight.shape[0]
344
+ alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
345
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
346
+
347
+ # split the scale between the down and up weights so that scale_down * scale_up == scale
348
+ scale_down = scale
349
+ scale_up = 1.0
350
+ while scale_down * 2 < scale_up:
351
+ scale_down *= 2
352
+ scale_up /= 2
353
+
354
+ ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
355
+ ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
356
+
357
+ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
358
+ if sds_key + ".lora_down.weight" not in sds_sd:
359
+ return
360
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
361
+ up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
362
+ sd_lora_rank = down_weight.shape[0]
363
+
364
+ # scale weight by alpha and dim
365
+ alpha = sds_sd.pop(sds_key + ".alpha")
366
+ scale = alpha / sd_lora_rank
367
+
368
+ # calculate scale_down and scale_up
369
+ scale_down = scale
370
+ scale_up = 1.0
371
+ while scale_down * 2 < scale_up:
372
+ scale_down *= 2
373
+ scale_up /= 2
374
+
375
+ down_weight = down_weight * scale_down
376
+ up_weight = up_weight * scale_up
377
+
378
+ # calculate dims if not provided
379
+ num_splits = len(ait_keys)
380
+ if dims is None:
381
+ dims = [up_weight.shape[0] // num_splits] * num_splits
382
+ else:
383
+ assert sum(dims) == up_weight.shape[0]
384
+
385
+ # check upweight is sparse or not
386
+ is_sparse = False
387
+ if sd_lora_rank % num_splits == 0:
388
+ ait_rank = sd_lora_rank // num_splits
389
+ is_sparse = True
390
+ i = 0
391
+ for j in range(len(dims)):
392
+ for k in range(len(dims)):
393
+ if j == k:
394
+ continue
395
+ is_sparse = is_sparse and torch.all(
396
+ up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
397
+ )
398
+ i += dims[j]
399
+ if is_sparse:
400
+ logger.info(f"weight is sparse: {sds_key}")
401
+
402
+ # make ai-toolkit weight
403
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
404
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
405
+ if not is_sparse:
406
+ # down_weight is copied to each split
407
+ ait_sd.update({k: down_weight for k in ait_down_keys})
408
+
409
+ # up_weight is split to each split
410
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
411
+ else:
412
+ # down_weight is chunked to each split
413
+ ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
414
+
415
+ # up_weight is sparse: only non-zero values are copied to each split
416
+ i = 0
417
+ for j in range(len(dims)):
418
+ ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
419
+ i += dims[j]
420
+
421
+ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
422
+ ait_sd = {}
423
+ for i in range(19):
424
+ _convert_to_ai_toolkit(
425
+ sds_sd,
426
+ ait_sd,
427
+ f"lora_unet_double_blocks_{i}_img_attn_proj",
428
+ f"transformer.transformer_blocks.{i}.attn.to_out.0",
429
+ )
430
+ _convert_to_ai_toolkit_cat(
431
+ sds_sd,
432
+ ait_sd,
433
+ f"lora_unet_double_blocks_{i}_img_attn_qkv",
434
+ [
435
+ f"transformer.transformer_blocks.{i}.attn.to_q",
436
+ f"transformer.transformer_blocks.{i}.attn.to_k",
437
+ f"transformer.transformer_blocks.{i}.attn.to_v",
438
+ ],
439
+ )
440
+ _convert_to_ai_toolkit(
441
+ sds_sd,
442
+ ait_sd,
443
+ f"lora_unet_double_blocks_{i}_img_mlp_0",
444
+ f"transformer.transformer_blocks.{i}.ff.net.0.proj",
445
+ )
446
+ _convert_to_ai_toolkit(
447
+ sds_sd,
448
+ ait_sd,
449
+ f"lora_unet_double_blocks_{i}_img_mlp_2",
450
+ f"transformer.transformer_blocks.{i}.ff.net.2",
451
+ )
452
+ _convert_to_ai_toolkit(
453
+ sds_sd,
454
+ ait_sd,
455
+ f"lora_unet_double_blocks_{i}_img_mod_lin",
456
+ f"transformer.transformer_blocks.{i}.norm1.linear",
457
+ )
458
+ _convert_to_ai_toolkit(
459
+ sds_sd,
460
+ ait_sd,
461
+ f"lora_unet_double_blocks_{i}_txt_attn_proj",
462
+ f"transformer.transformer_blocks.{i}.attn.to_add_out",
463
+ )
464
+ _convert_to_ai_toolkit_cat(
465
+ sds_sd,
466
+ ait_sd,
467
+ f"lora_unet_double_blocks_{i}_txt_attn_qkv",
468
+ [
469
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj",
470
+ f"transformer.transformer_blocks.{i}.attn.add_k_proj",
471
+ f"transformer.transformer_blocks.{i}.attn.add_v_proj",
472
+ ],
473
+ )
474
+ _convert_to_ai_toolkit(
475
+ sds_sd,
476
+ ait_sd,
477
+ f"lora_unet_double_blocks_{i}_txt_mlp_0",
478
+ f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
479
+ )
480
+ _convert_to_ai_toolkit(
481
+ sds_sd,
482
+ ait_sd,
483
+ f"lora_unet_double_blocks_{i}_txt_mlp_2",
484
+ f"transformer.transformer_blocks.{i}.ff_context.net.2",
485
+ )
486
+ _convert_to_ai_toolkit(
487
+ sds_sd,
488
+ ait_sd,
489
+ f"lora_unet_double_blocks_{i}_txt_mod_lin",
490
+ f"transformer.transformer_blocks.{i}.norm1_context.linear",
491
+ )
492
+
493
+ for i in range(38):
494
+ _convert_to_ai_toolkit_cat(
495
+ sds_sd,
496
+ ait_sd,
497
+ f"lora_unet_single_blocks_{i}_linear1",
498
+ [
499
+ f"transformer.single_transformer_blocks.{i}.attn.to_q",
500
+ f"transformer.single_transformer_blocks.{i}.attn.to_k",
501
+ f"transformer.single_transformer_blocks.{i}.attn.to_v",
502
+ f"transformer.single_transformer_blocks.{i}.proj_mlp",
503
+ ],
504
+ dims=[3072, 3072, 3072, 12288],
505
+ )
506
+ _convert_to_ai_toolkit(
507
+ sds_sd,
508
+ ait_sd,
509
+ f"lora_unet_single_blocks_{i}_linear2",
510
+ f"transformer.single_transformer_blocks.{i}.proj_out",
511
+ )
512
+ _convert_to_ai_toolkit(
513
+ sds_sd,
514
+ ait_sd,
515
+ f"lora_unet_single_blocks_{i}_modulation_lin",
516
+ f"transformer.single_transformer_blocks.{i}.norm.linear",
517
+ )
518
+
519
+ remaining_keys = list(sds_sd.keys())
520
+ te_state_dict = {}
521
+ if remaining_keys:
522
+ if not all(k.startswith("lora_te1") for k in remaining_keys):
523
+ raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524
+ for key in remaining_keys:
525
+ if not key.endswith("lora_down.weight"):
526
+ continue
527
+
528
+ lora_name = key.split(".")[0]
529
+ lora_name_up = f"{lora_name}.lora_up.weight"
530
+ lora_name_alpha = f"{lora_name}.alpha"
531
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
532
+
533
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
534
+ down_weight = sds_sd.pop(key)
535
+ sd_lora_rank = down_weight.shape[0]
536
+ te_state_dict[diffusers_name] = down_weight
537
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
538
+
539
+ if lora_name_alpha in sds_sd:
540
+ alpha = sds_sd.pop(lora_name_alpha).item()
541
+ scale = alpha / sd_lora_rank
542
+
543
+ scale_down = scale
544
+ scale_up = 1.0
545
+ while scale_down * 2 < scale_up:
546
+ scale_down *= 2
547
+ scale_up /= 2
548
+
549
+ te_state_dict[diffusers_name] *= scale_down
550
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
551
+
552
+ if len(sds_sd) > 0:
553
+ logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
554
+
555
+ if te_state_dict:
556
+ te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
557
+
558
+ new_state_dict = {**ait_sd, **te_state_dict}
559
+ return new_state_dict
560
+
561
+ return _convert_sd_scripts_to_ai_toolkit(state_dict)
562
+
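Two points worth spelling out about the conversion above, both read directly off the helpers rather than new behavior: sd-scripts keys such as `lora_unet_double_blocks_0_img_attn_proj` are mapped to PEFT-style keys like `transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight` / `.lora_B.weight`, and the checkpoint's `alpha` is folded into the weights so that no separate scaling is needed at load time. A sketch of the alpha handling in isolation, with illustrative values:

```python
# The alpha handling from _convert_to_ai_toolkit, isolated for clarity.
rank, alpha = 16, 8.0
scale = alpha / rank            # sd-scripts applies alpha / rank in its forward pass
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
    scale_down *= 2
    scale_up /= 2
# The product scale_down * scale_up stays equal to `scale`, so
# (up * scale_up) @ (down * scale_down) == scale * (up @ down) and alpha can be dropped.
```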
563
+
564
+ # Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
565
+ # Some utilities were reused from
566
+ # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
567
+ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
568
+ new_state_dict = {}
569
+ orig_keys = list(old_state_dict.keys())
570
+
571
+ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
572
+ down_weight = sds_sd.pop(sds_key)
573
+ up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
574
+
575
+ # calculate dims if not provided
576
+ num_splits = len(ait_keys)
577
+ if dims is None:
578
+ dims = [up_weight.shape[0] // num_splits] * num_splits
579
+ else:
580
+ assert sum(dims) == up_weight.shape[0]
581
+
582
+ # make ai-toolkit weight
583
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
584
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
585
+
586
+ # down_weight is copied to each split
587
+ ait_sd.update({k: down_weight for k in ait_down_keys})
588
+
589
+ # up_weight is split to each split
590
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
591
+
592
+ for old_key in orig_keys:
593
+ # Handle double_blocks
594
+ if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
595
+ block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
596
+ new_key = f"transformer.transformer_blocks.{block_num}"
597
+
598
+ if "processor.proj_lora1" in old_key:
599
+ new_key += ".attn.to_out.0"
600
+ elif "processor.proj_lora2" in old_key:
601
+ new_key += ".attn.to_add_out"
602
+ # Handle text latents.
603
+ elif "processor.qkv_lora2" in old_key and "up" not in old_key:
604
+ handle_qkv(
605
+ old_state_dict,
606
+ new_state_dict,
607
+ old_key,
608
+ [
609
+ f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
610
+ f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
611
+ f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
612
+ ],
613
+ )
614
+ # continue
615
+ # Handle image latents.
616
+ elif "processor.qkv_lora1" in old_key and "up" not in old_key:
617
+ handle_qkv(
618
+ old_state_dict,
619
+ new_state_dict,
620
+ old_key,
621
+ [
622
+ f"transformer.transformer_blocks.{block_num}.attn.to_q",
623
+ f"transformer.transformer_blocks.{block_num}.attn.to_k",
624
+ f"transformer.transformer_blocks.{block_num}.attn.to_v",
625
+ ],
626
+ )
627
+ # continue
628
+
629
+ if "down" in old_key:
630
+ new_key += ".lora_A.weight"
631
+ elif "up" in old_key:
632
+ new_key += ".lora_B.weight"
633
+
634
+ # Handle single_blocks
635
+ elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
636
+ block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
637
+ new_key = f"transformer.single_transformer_blocks.{block_num}"
638
+
639
+ if "proj_lora" in old_key:
640
+ new_key += ".proj_out"
641
+ elif "qkv_lora" in old_key and "up" not in old_key:
642
+ handle_qkv(
643
+ old_state_dict,
644
+ new_state_dict,
645
+ old_key,
646
+ [
647
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
648
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
649
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
650
+ ],
651
+ )
652
+
653
+ if "down" in old_key:
654
+ new_key += ".lora_A.weight"
655
+ elif "up" in old_key:
656
+ new_key += ".lora_B.weight"
657
+
658
+ else:
659
+ # Handle other potential key patterns here
660
+ new_key = old_key
661
+
662
+ # Since we already handle qkv above.
663
+ if "qkv" not in old_key:
664
+ new_state_dict[new_key] = old_state_dict.pop(old_key)
665
+
666
+ if len(old_state_dict) > 0:
667
+ raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
668
+
669
+ return new_state_dict
670
+
671
+
672
+ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
673
+ converted_state_dict = {}
674
+ original_state_dict_keys = list(original_state_dict.keys())
675
+ num_layers = 19
676
+ num_single_layers = 38
677
+ inner_dim = 3072
678
+ mlp_ratio = 4.0
679
+
680
+ def swap_scale_shift(weight):
681
+ shift, scale = weight.chunk(2, dim=0)
682
+ new_weight = torch.cat([scale, shift], dim=0)
683
+ return new_weight
684
+
685
+ for lora_key in ["lora_A", "lora_B"]:
686
+ ## time_text_embed.timestep_embedder <- time_in
687
+ converted_state_dict[
688
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
689
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
690
+ if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
691
+ converted_state_dict[
692
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
693
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
694
+
695
+ converted_state_dict[
696
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
697
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
698
+ if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
699
+ converted_state_dict[
700
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
701
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
702
+
703
+ ## time_text_embed.text_embedder <- vector_in
704
+ converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
705
+ f"vector_in.in_layer.{lora_key}.weight"
706
+ )
707
+ if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
708
+ converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
709
+ f"vector_in.in_layer.{lora_key}.bias"
710
+ )
711
+
712
+ converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
713
+ f"vector_in.out_layer.{lora_key}.weight"
714
+ )
715
+ if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
716
+ converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
717
+ f"vector_in.out_layer.{lora_key}.bias"
718
+ )
719
+
720
+ # guidance
721
+ has_guidance = any("guidance" in k for k in original_state_dict)
722
+ if has_guidance:
723
+ converted_state_dict[
724
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
725
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
726
+ if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
727
+ converted_state_dict[
728
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
729
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
730
+
731
+ converted_state_dict[
732
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
733
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
734
+ if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
735
+ converted_state_dict[
736
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
737
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
738
+
739
+ # context_embedder
740
+ converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
741
+ f"txt_in.{lora_key}.weight"
742
+ )
743
+ if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
744
+ converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
745
+ f"txt_in.{lora_key}.bias"
746
+ )
747
+
748
+ # x_embedder
749
+ converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
750
+ if f"img_in.{lora_key}.bias" in original_state_dict_keys:
751
+ converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
752
+
753
+ # double transformer blocks
754
+ for i in range(num_layers):
755
+ block_prefix = f"transformer_blocks.{i}."
756
+
757
+ for lora_key in ["lora_A", "lora_B"]:
758
+ # norms
759
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
760
+ f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
761
+ )
762
+ if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
763
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
764
+ f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
765
+ )
766
+
767
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
768
+ f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
769
+ )
770
+ if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
771
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
772
+ f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
773
+ )
774
+
775
+ # Q, K, V
776
+ if lora_key == "lora_A":
777
+ sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
778
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
779
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
780
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
781
+
782
+ context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
783
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
784
+ [context_lora_weight]
785
+ )
786
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
787
+ [context_lora_weight]
788
+ )
789
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
790
+ [context_lora_weight]
791
+ )
792
+ else:
793
+ sample_q, sample_k, sample_v = torch.chunk(
794
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
795
+ )
796
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
797
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
798
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
799
+
800
+ context_q, context_k, context_v = torch.chunk(
801
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
802
+ )
803
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
804
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
805
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
806
+
807
+ if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
808
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
809
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
810
+ )
811
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
812
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
813
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
814
+
815
+ if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
816
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
817
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
818
+ )
819
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
820
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
821
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
822
+
823
+ # ff img_mlp
824
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
825
+ f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
826
+ )
827
+ if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
828
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
829
+ f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
830
+ )
831
+
832
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
833
+ f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
834
+ )
835
+ if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
836
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
837
+ f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
838
+ )
839
+
840
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
841
+ f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
842
+ )
843
+ if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
844
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
845
+ f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
846
+ )
847
+
848
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
849
+ f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
850
+ )
851
+ if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
852
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
853
+ f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
854
+ )
855
+
856
+ # output projections.
857
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
858
+ f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
859
+ )
860
+ if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
861
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
862
+ f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
863
+ )
864
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
865
+ f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
866
+ )
867
+ if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
868
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
869
+ f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
870
+ )
871
+
872
+ # qk_norm
873
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
874
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
875
+ )
876
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
877
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
878
+ )
879
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
880
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
881
+ )
882
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
883
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
884
+ )
885
+
886
+ # single transformer blocks
887
+ for i in range(num_single_layers):
888
+ block_prefix = f"single_transformer_blocks.{i}."
889
+
890
+ for lora_key in ["lora_A", "lora_B"]:
891
+ # norm.linear <- single_blocks.0.modulation.lin
892
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
893
+ f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
894
+ )
895
+ if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
896
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
897
+ f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
898
+ )
899
+
900
+ # Q, K, V, mlp
901
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
902
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
903
+
904
+ if lora_key == "lora_A":
905
+ lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
906
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
907
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
908
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
909
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
910
+
911
+ if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
912
+ lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
913
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
914
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
915
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
916
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
917
+ else:
918
+ q, k, v, mlp = torch.split(
919
+ original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
920
+ )
921
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
922
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
923
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
924
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
925
+
926
+ if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
927
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
928
+ original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
929
+ )
930
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
931
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
932
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
933
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
934
+
935
+ # output projections.
936
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
937
+ f"single_blocks.{i}.linear2.{lora_key}.weight"
938
+ )
939
+ if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
940
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
941
+ f"single_blocks.{i}.linear2.{lora_key}.bias"
942
+ )
943
+
944
+ # qk norm
945
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
946
+ f"single_blocks.{i}.norm.query_norm.scale"
947
+ )
948
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
949
+ f"single_blocks.{i}.norm.key_norm.scale"
950
+ )
951
+
952
+ for lora_key in ["lora_A", "lora_B"]:
953
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
954
+ f"final_layer.linear.{lora_key}.weight"
955
+ )
956
+ if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
957
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
958
+ f"final_layer.linear.{lora_key}.bias"
959
+ )
960
+
961
+ converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
962
+ original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
963
+ )
964
+ if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
965
+ converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
966
+ original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
967
+ )
968
+
969
+ if len(original_state_dict) > 0:
970
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
971
+
972
+ for key in list(converted_state_dict.keys()):
973
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
974
+
975
+ return converted_state_dict
976
+
977
+
978
+ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
979
+ converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
980
+
981
+ def remap_norm_scale_shift_(key, state_dict):
982
+ weight = state_dict.pop(key)
983
+ shift, scale = weight.chunk(2, dim=0)
984
+ new_weight = torch.cat([scale, shift], dim=0)
985
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
986
+
987
+ def remap_txt_in_(key, state_dict):
988
+ def rename_key(key):
989
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
990
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
991
+ new_key = new_key.replace("txt_in", "context_embedder")
992
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
993
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
994
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
995
+ new_key = new_key.replace("mlp", "ff")
996
+ return new_key
997
+
998
+ if "self_attn_qkv" in key:
999
+ weight = state_dict.pop(key)
1000
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1001
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
1002
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
1003
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
1004
+ else:
1005
+ state_dict[rename_key(key)] = state_dict.pop(key)
1006
+
1007
+ def remap_img_attn_qkv_(key, state_dict):
1008
+ weight = state_dict.pop(key)
1009
+ if "lora_A" in key:
1010
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
1011
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
1012
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
1013
+ else:
1014
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1015
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
1016
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
1017
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
1018
+
1019
+ def remap_txt_attn_qkv_(key, state_dict):
1020
+ weight = state_dict.pop(key)
1021
+ if "lora_A" in key:
1022
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
1023
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
1024
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
1025
+ else:
1026
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
1027
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
1028
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
1029
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
1030
+
1031
+ def remap_single_transformer_blocks_(key, state_dict):
1032
+ hidden_size = 3072
1033
+
1034
+ if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
1035
+ linear1_weight = state_dict.pop(key)
1036
+ if "lora_A" in key:
1037
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1038
+ ".linear1.lora_A.weight"
1039
+ )
1040
+ state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
1041
+ state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
1042
+ state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
1043
+ state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
1044
+ else:
1045
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
1046
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
1047
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1048
+ ".linear1.lora_B.weight"
1049
+ )
1050
+ state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
1051
+ state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
1052
+ state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
1053
+ state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
1054
+
1055
+ elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
1056
+ linear1_bias = state_dict.pop(key)
1057
+ if "lora_A" in key:
1058
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1059
+ ".linear1.lora_A.bias"
1060
+ )
1061
+ state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
1062
+ state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
1063
+ state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
1064
+ state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
1065
+ else:
1066
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
1067
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
1068
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
1069
+ ".linear1.lora_B.bias"
1070
+ )
1071
+ state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
1072
+ state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
1073
+ state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
1074
+ state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
1075
+
1076
+ else:
1077
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
1078
+ new_key = new_key.replace("linear2", "proj_out")
1079
+ new_key = new_key.replace("q_norm", "attn.norm_q")
1080
+ new_key = new_key.replace("k_norm", "attn.norm_k")
1081
+ state_dict[new_key] = state_dict.pop(key)
1082
+
1083
+ TRANSFORMER_KEYS_RENAME_DICT = {
1084
+ "img_in": "x_embedder",
1085
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
1086
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
1087
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
1088
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
1089
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
1090
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
1091
+ "double_blocks": "transformer_blocks",
1092
+ "img_attn_q_norm": "attn.norm_q",
1093
+ "img_attn_k_norm": "attn.norm_k",
1094
+ "img_attn_proj": "attn.to_out.0",
1095
+ "txt_attn_q_norm": "attn.norm_added_q",
1096
+ "txt_attn_k_norm": "attn.norm_added_k",
1097
+ "txt_attn_proj": "attn.to_add_out",
1098
+ "img_mod.linear": "norm1.linear",
1099
+ "img_norm1": "norm1.norm",
1100
+ "img_norm2": "norm2",
1101
+ "img_mlp": "ff",
1102
+ "txt_mod.linear": "norm1_context.linear",
1103
+ "txt_norm1": "norm1.norm",
1104
+ "txt_norm2": "norm2_context",
1105
+ "txt_mlp": "ff_context",
1106
+ "self_attn_proj": "attn.to_out.0",
1107
+ "modulation.linear": "norm.linear",
1108
+ "pre_norm": "norm.norm",
1109
+ "final_layer.norm_final": "norm_out.norm",
1110
+ "final_layer.linear": "proj_out",
1111
+ "fc1": "net.0.proj",
1112
+ "fc2": "net.2",
1113
+ "input_embedder": "proj_in",
1114
+ }
1115
+
1116
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
1117
+ "txt_in": remap_txt_in_,
1118
+ "img_attn_qkv": remap_img_attn_qkv_,
1119
+ "txt_attn_qkv": remap_txt_attn_qkv_,
1120
+ "single_blocks": remap_single_transformer_blocks_,
1121
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
1122
+ }
1123
+
1124
+ # Some folks attempt to make their state dict compatible with diffusers by adding a "transformer." prefix to all keys
1125
+ # and using custom code. To make sure both "original" and "attempted diffusers" LoRAs work as expected, we make
1126
+ # sure that both follow the same initial format by stripping off the "transformer." prefix.
1127
+ for key in list(converted_state_dict.keys()):
1128
+ if key.startswith("transformer."):
1129
+ converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
1130
+ if key.startswith("diffusion_model."):
1131
+ converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
1132
+
1133
+ # Rename and remap the state dict keys
1134
+ for key in list(converted_state_dict.keys()):
1135
+ new_key = key[:]
1136
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
1137
+ new_key = new_key.replace(replace_key, rename_key)
1138
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
1139
+
1140
+ for key in list(converted_state_dict.keys()):
1141
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
1142
+ if special_key not in key:
1143
+ continue
1144
+ handler_fn_inplace(key, converted_state_dict)
1145
+
1146
+ # Add back the "transformer." prefix
1147
+ for key in list(converted_state_dict.keys()):
1148
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1149
+
1150
+ return converted_state_dict
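As a point of reference, here is a minimal sketch of how the converter above could be exercised on a raw HunyuanVideo-style LoRA checkpoint; the file name is a made-up placeholder and the converter is the private helper defined in this diff.

```py
# Minimal sketch (hypothetical file name; assumes the converter above is in scope).
import safetensors.torch

raw_state_dict = safetensors.torch.load_file("hunyuan_video_lora.safetensors")  # placeholder path
converted = _convert_hunyuan_video_lora_to_diffusers(raw_state_dict)

# After conversion, every key carries the "transformer." prefix expected by the diffusers LoRA loaders.
assert all(k.startswith("transformer.") for k in converted)
```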
icedit/diffusers/loaders/lora_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
icedit/diffusers/loaders/peft.py ADDED
@@ -0,0 +1,750 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import inspect
16
+ import os
17
+ from functools import partial
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import safetensors
22
+ import torch
23
+
24
+ from ..utils import (
25
+ MIN_PEFT_VERSION,
26
+ USE_PEFT_BACKEND,
27
+ check_peft_version,
28
+ convert_unet_state_dict_to_peft,
29
+ delete_adapter_layers,
30
+ get_adapter_name,
31
+ get_peft_kwargs,
32
+ is_peft_available,
33
+ is_peft_version,
34
+ logging,
35
+ set_adapter_layers,
36
+ set_weights_and_activate_adapters,
37
+ )
38
+ from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
39
+ from .unet_loader_utils import _maybe_expand_lora_scales
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ _SET_ADAPTER_SCALE_FN_MAPPING = {
45
+ "UNet2DConditionModel": _maybe_expand_lora_scales,
46
+ "UNetMotionModel": _maybe_expand_lora_scales,
47
+ "SD3Transformer2DModel": lambda model_cls, weights: weights,
48
+ "FluxTransformer2DModel": lambda model_cls, weights: weights,
49
+ "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
50
+ "MochiTransformer3DModel": lambda model_cls, weights: weights,
51
+ "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
52
+ "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
53
+ "SanaTransformer2DModel": lambda model_cls, weights: weights,
54
+ }
55
+
56
+
57
+ def _maybe_adjust_config(config):
58
+ """
59
+ We may run into some ambiguous configuration values when a model has module names sharing a common prefix
60
+ (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
61
+ method removes the ambiguity by following what is described here:
62
+ https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
63
+ """
64
+ rank_pattern = config["rank_pattern"].copy()
65
+ target_modules = config["target_modules"]
66
+ original_r = config["r"]
67
+
68
+ for key in list(rank_pattern.keys()):
69
+ key_rank = rank_pattern[key]
70
+
71
+ # try to detect ambiguity
72
+ # `target_modules` can also be a str, in which case this loop would loop
73
+ # over the chars of the str. The technically correct way to match LoRA keys
74
+ # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
75
+ # But this simple check suffices for now.
76
+ exact_matches = [mod for mod in target_modules if mod == key]
77
+ substring_matches = [mod for mod in target_modules if key in mod and mod != key]
78
+ ambiguous_key = key
79
+
80
+ if exact_matches and substring_matches:
81
+ # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
82
+ config["r"] = key_rank
83
+ # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
84
+ del config["rank_pattern"][key]
85
+ for mod in substring_matches:
86
+ # avoid overwriting if the module already has a specific rank
87
+ if mod not in config["rank_pattern"]:
88
+ config["rank_pattern"][mod] = original_r
89
+
90
+ # update the rest of the keys with the `original_r`
91
+ for mod in target_modules:
92
+ if mod != ambiguous_key and mod not in config["rank_pattern"]:
93
+ config["rank_pattern"][mod] = original_r
94
+
95
+ # handle alphas to deal with cases like
96
+ # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
97
+ has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
98
+ if has_different_ranks:
99
+ config["lora_alpha"] = config["r"]
100
+ alpha_pattern = {}
101
+ for module_name, rank in config["rank_pattern"].items():
102
+ alpha_pattern[module_name] = rank
103
+ config["alpha_pattern"] = alpha_pattern
104
+
105
+ return config
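As a quick illustration of the disambiguation above, consider a toy config in which the short module name `proj_out` collides with a longer path; all values below are made up for the example.

```py
# Toy example (made-up values) showing the effect of _maybe_adjust_config.
config = {
    "r": 8,
    "lora_alpha": 8,
    "rank_pattern": {"proj_out": 4},
    "target_modules": ["proj_out", "blocks.transformer.proj_out"],
}

adjusted = _maybe_adjust_config(config)
# The ambiguous short key is promoted to the global rank: adjusted["r"] == 4,
# while the longer, unambiguous module keeps the original rank:
# adjusted["rank_pattern"] == {"blocks.transformer.proj_out": 8}
```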
106
+
107
+
108
+ class PeftAdapterMixin:
109
+ """
110
+ A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
111
+ more details about adapters and injecting them in a base model, check out the PEFT
112
+ [documentation](https://huggingface.co/docs/peft/index).
113
+
114
+ Install the latest version of PEFT, and use this mixin to:
115
+
116
+ - Attach new adapters in the model.
117
+ - Attach multiple adapters and iteratively activate/deactivate them.
118
+ - Activate/deactivate all adapters from the model.
119
+ - Get a list of the active adapters.
120
+ """
121
+
122
+ _hf_peft_config_loaded = False
123
+
124
+ @classmethod
125
+ # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
126
+ def _optionally_disable_offloading(cls, _pipeline):
127
+ """
128
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
129
+
130
+ Args:
131
+ _pipeline (`DiffusionPipeline`):
132
+ The pipeline to disable offloading for.
133
+
134
+ Returns:
135
+ tuple:
136
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
137
+ """
138
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
139
+
140
+ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
141
+ r"""
142
+ Loads a LoRA adapter into the underlying model.
143
+
144
+ Parameters:
145
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
146
+ Can be either:
147
+
148
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
149
+ the Hub.
150
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
151
+ with [`ModelMixin.save_pretrained`].
152
+ - A [torch state
153
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
154
+
155
+ prefix (`str`, *optional*): Prefix to filter the state dict.
156
+
157
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
158
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
159
+ is not used.
160
+ force_download (`bool`, *optional*, defaults to `False`):
161
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
162
+ cached versions if they exist.
163
+ proxies (`Dict[str, str]`, *optional*):
164
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
165
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
166
+ local_files_only (`bool`, *optional*, defaults to `False`):
167
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
168
+ won't be downloaded from the Hub.
169
+ token (`str` or *bool*, *optional*):
170
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
171
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
172
+ revision (`str`, *optional*, defaults to `"main"`):
173
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
174
+ allowed by Git.
175
+ subfolder (`str`, *optional*, defaults to `""`):
176
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
177
+ network_alphas (`Dict[str, float]`):
178
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
179
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
180
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
181
+ low_cpu_mem_usage (`bool`, *optional*):
182
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
183
+ weights.
184
+ """
185
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
186
+ from peft.tuners.tuners_utils import BaseTunerLayer
187
+
188
+ cache_dir = kwargs.pop("cache_dir", None)
189
+ force_download = kwargs.pop("force_download", False)
190
+ proxies = kwargs.pop("proxies", None)
191
+ local_files_only = kwargs.pop("local_files_only", None)
192
+ token = kwargs.pop("token", None)
193
+ revision = kwargs.pop("revision", None)
194
+ subfolder = kwargs.pop("subfolder", None)
195
+ weight_name = kwargs.pop("weight_name", None)
196
+ use_safetensors = kwargs.pop("use_safetensors", None)
197
+ adapter_name = kwargs.pop("adapter_name", None)
198
+ network_alphas = kwargs.pop("network_alphas", None)
199
+ _pipeline = kwargs.pop("_pipeline", None)
200
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
201
+ allow_pickle = False
202
+
203
+ if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
204
+ raise ValueError(
205
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
206
+ )
207
+
208
+ user_agent = {
209
+ "file_type": "attn_procs_weights",
210
+ "framework": "pytorch",
211
+ }
212
+
213
+ state_dict = _fetch_state_dict(
214
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
215
+ weight_name=weight_name,
216
+ use_safetensors=use_safetensors,
217
+ local_files_only=local_files_only,
218
+ cache_dir=cache_dir,
219
+ force_download=force_download,
220
+ proxies=proxies,
221
+ token=token,
222
+ revision=revision,
223
+ subfolder=subfolder,
224
+ user_agent=user_agent,
225
+ allow_pickle=allow_pickle,
226
+ )
227
+ if network_alphas is not None and prefix is None:
228
+ raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
229
+
230
+ if prefix is not None:
231
+ keys = list(state_dict.keys())
232
+ model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
233
+ if len(model_keys) > 0:
234
+ state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
235
+
236
+ if len(state_dict) > 0:
237
+ if adapter_name in getattr(self, "peft_config", {}):
238
+ raise ValueError(
239
+ f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
240
+ )
241
+
242
+ # check with first key if is not in peft format
243
+ first_key = next(iter(state_dict.keys()))
244
+ if "lora_A" not in first_key:
245
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
246
+
247
+ rank = {}
248
+ for key, val in state_dict.items():
249
+ # Cannot figure out rank from LoRA layers that don't have at least 2 dimensions.
250
+ # Bias layers in LoRA only have a single dimension
251
+ if "lora_B" in key and val.ndim > 1:
252
+ rank[key] = val.shape[1]
253
+
254
+ if network_alphas is not None and len(network_alphas) >= 1:
255
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
256
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
257
+
258
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
259
+ # lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) # TODO: remove this for moe
260
+
261
+ if "use_dora" in lora_config_kwargs:
262
+ if lora_config_kwargs["use_dora"]:
263
+ if is_peft_version("<", "0.9.0"):
264
+ raise ValueError(
265
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
266
+ )
267
+ else:
268
+ if is_peft_version("<", "0.9.0"):
269
+ lora_config_kwargs.pop("use_dora")
270
+
271
+ if "lora_bias" in lora_config_kwargs:
272
+ if lora_config_kwargs["lora_bias"]:
273
+ if is_peft_version("<=", "0.13.2"):
274
+ raise ValueError(
275
+ "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
276
+ )
277
+ else:
278
+ if is_peft_version("<=", "0.13.2"):
279
+ lora_config_kwargs.pop("lora_bias")
280
+
281
+ lora_config = LoraConfig(**lora_config_kwargs)
282
+ # adapter_name
283
+ if adapter_name is None:
284
+ adapter_name = get_adapter_name(self)
285
+
286
+ # <Unsafe code
287
+ # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
288
+ # Now we remove any existing hooks to `_pipeline`.
289
+
290
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
291
+ # otherwise loading LoRA weights will lead to an error
292
+ is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
293
+
294
+ peft_kwargs = {}
295
+ if is_peft_version(">=", "0.13.1"):
296
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
297
+
298
+ # To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
299
+ # we should also delete the `peft_config` associated with the `adapter_name`.
300
+ try:
301
+ inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
302
+ incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
303
+ except RuntimeError as e:
304
+ for module in self.modules():
305
+ if isinstance(module, BaseTunerLayer):
306
+ active_adapters = module.active_adapters
307
+ for active_adapter in active_adapters:
308
+ if adapter_name in active_adapter:
309
+ module.delete_adapter(adapter_name)
310
+
311
+ self.peft_config.pop(adapter_name)
312
+ logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
313
+ raise
314
+
315
+ warn_msg = ""
316
+ if incompatible_keys is not None:
317
+ # Check only for unexpected keys.
318
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
319
+ if unexpected_keys:
320
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
321
+ if lora_unexpected_keys:
322
+ warn_msg = (
323
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
324
+ f" {', '.join(lora_unexpected_keys)}. "
325
+ )
326
+
327
+ # Filter missing keys specific to the current adapter.
328
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
329
+ if missing_keys:
330
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
331
+ if lora_missing_keys:
332
+ warn_msg += (
333
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
334
+ f" {', '.join(lora_missing_keys)}."
335
+ )
336
+
337
+ if warn_msg:
338
+ logger.warning(warn_msg)
339
+
340
+ # Offload back.
341
+ if is_model_cpu_offload:
342
+ _pipeline.enable_model_cpu_offload()
343
+ elif is_sequential_cpu_offload:
344
+ _pipeline.enable_sequential_cpu_offload()
345
+ # Unsafe code />
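A brief sketch of calling this method directly on a model rather than through a pipeline; the repository ids, file name, and adapter name below are hypothetical placeholders, not part of this commit.

```py
# Hypothetical model-level usage of load_lora_adapter() as defined above.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "some-org/some-flux-checkpoint",  # placeholder base model repo
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Keys are filtered with prefix="transformer" (the default) and injected via PEFT.
transformer.load_lora_adapter(
    "some-org/flux-style-lora",  # placeholder LoRA repo
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="style",
)
```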
346
+
347
+ def save_lora_adapter(
348
+ self,
349
+ save_directory,
350
+ adapter_name: str = "default",
351
+ upcast_before_saving: bool = False,
352
+ safe_serialization: bool = True,
353
+ weight_name: Optional[str] = None,
354
+ ):
355
+ """
356
+ Save the LoRA parameters corresponding to the underlying model.
357
+
358
+ Arguments:
359
+ save_directory (`str` or `os.PathLike`):
360
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
361
+ adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
362
+ underlying model has multiple adapters loaded.
363
+ upcast_before_saving (`bool`, defaults to `False`):
364
+ Whether to cast the underlying model to `torch.float32` before serialization.
365
+ save_function (`Callable`):
366
+ The function to use to save the state dictionary. Useful during distributed training when you need to
367
+ replace `torch.save` with another method. Can be configured with the environment variable
368
+ `DIFFUSERS_SAVE_MODE`.
369
+ safe_serialization (`bool`, *optional*, defaults to `True`):
370
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
371
+ weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
372
+ """
373
+ from peft.utils import get_peft_model_state_dict
374
+
375
+ from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
376
+
377
+ if adapter_name is None:
378
+ adapter_name = get_adapter_name(self)
379
+
380
+ if adapter_name not in getattr(self, "peft_config", {}):
381
+ raise ValueError(f"Adapter name {adapter_name} not found in the model.")
382
+
383
+ lora_layers_to_save = get_peft_model_state_dict(
384
+ self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
385
+ )
386
+ if os.path.isfile(save_directory):
387
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
388
+
389
+ if safe_serialization:
390
+
391
+ def save_function(weights, filename):
392
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
393
+
394
+ else:
395
+ save_function = torch.save
396
+
397
+ os.makedirs(save_directory, exist_ok=True)
398
+
399
+ if weight_name is None:
400
+ if safe_serialization:
401
+ weight_name = LORA_WEIGHT_NAME_SAFE
402
+ else:
403
+ weight_name = LORA_WEIGHT_NAME
404
+
405
+ # TODO: we could consider saving the `peft_config` as well.
406
+ save_path = Path(save_directory, weight_name).as_posix()
407
+ save_function(lora_layers_to_save, save_path)
408
+ logger.info(f"Model weights saved in {save_path}")
409
+
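Continuing the hypothetical sketch above, the loaded adapter can be written back out with `save_lora_adapter`; the target directory is made up.

```py
# Serializes only the LoRA layers of the "style" adapter (placeholder names from the sketch above).
transformer.save_lora_adapter("./style-lora", adapter_name="style", safe_serialization=True)
```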
410
+ def set_adapters(
411
+ self,
412
+ adapter_names: Union[List[str], str],
413
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
414
+ ):
415
+ """
416
+ Set the currently active adapters for use in the UNet.
417
+
418
+ Args:
419
+ adapter_names (`List[str]` or `str`):
420
+ The names of the adapters to use.
421
+ adapter_weights (`Union[List[float], float]`, *optional*):
422
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
423
+ adapters.
424
+
425
+ Example:
426
+
427
+ ```py
428
+ from diffusers import AutoPipelineForText2Image
429
+ import torch
430
+
431
+ pipeline = AutoPipelineForText2Image.from_pretrained(
432
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
433
+ ).to("cuda")
434
+ pipeline.load_lora_weights(
435
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
436
+ )
437
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
438
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
439
+ ```
440
+ """
441
+ if not USE_PEFT_BACKEND:
442
+ raise ValueError("PEFT backend is required for `set_adapters()`.")
443
+
444
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
445
+
446
+ # Expand weights into a list, one entry per adapter
447
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
448
+ if not isinstance(weights, list):
449
+ weights = [weights] * len(adapter_names)
450
+
451
+ if len(adapter_names) != len(weights):
452
+ raise ValueError(
453
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
454
+ )
455
+
456
+ # Set None values to default of 1.0
457
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
458
+ weights = [w if w is not None else 1.0 for w in weights]
459
+
460
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
461
+ scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
462
+ weights = scale_expansion_fn(self, weights)
463
+
464
+ set_weights_and_activate_adapters(self, adapter_names, weights)
465
+
466
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
467
+ r"""
468
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
469
+ to the adapter to follow the convention of the PEFT library.
470
+
471
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
472
+ [documentation](https://huggingface.co/docs/peft).
473
+
474
+ Args:
475
+ adapter_config (`[~peft.PeftConfig]`):
476
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
477
+ methods.
478
+ adapter_name (`str`, *optional*, defaults to `"default"`):
479
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
480
+ """
481
+ check_peft_version(min_version=MIN_PEFT_VERSION)
482
+
483
+ if not is_peft_available():
484
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
485
+
486
+ from peft import PeftConfig, inject_adapter_in_model
487
+
488
+ if not self._hf_peft_config_loaded:
489
+ self._hf_peft_config_loaded = True
490
+ elif adapter_name in self.peft_config:
491
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
492
+
493
+ if not isinstance(adapter_config, PeftConfig):
494
+ raise ValueError(
495
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
496
+ )
497
+
498
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
499
+ # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
500
+ adapter_config.base_model_name_or_path = None
501
+ inject_adapter_in_model(adapter_config, self, adapter_name)
502
+ self.set_adapter(adapter_name)
503
+
504
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
505
+ """
506
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
507
+
508
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
509
+ [documentation](https://huggingface.co/docs/peft).
510
+
511
+ Args:
512
+ adapter_name (Union[str, List[str]])):
513
+ The list of adapters to set or the adapter name in the case of a single adapter.
514
+ """
515
+ check_peft_version(min_version=MIN_PEFT_VERSION)
516
+
517
+ if not self._hf_peft_config_loaded:
518
+ raise ValueError("No adapter loaded. Please load an adapter first.")
519
+
520
+ if isinstance(adapter_name, str):
521
+ adapter_name = [adapter_name]
522
+
523
+ missing = set(adapter_name) - set(self.peft_config)
524
+ if len(missing) > 0:
525
+ raise ValueError(
526
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
527
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
528
+ )
529
+
530
+ from peft.tuners.tuners_utils import BaseTunerLayer
531
+
532
+ _adapters_has_been_set = False
533
+
534
+ for _, module in self.named_modules():
535
+ if isinstance(module, BaseTunerLayer):
536
+ if hasattr(module, "set_adapter"):
537
+ module.set_adapter(adapter_name)
538
+ # Previous versions of PEFT do not support multi-adapter inference
539
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
540
+ raise ValueError(
541
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
542
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
543
+ )
544
+ else:
545
+ module.active_adapter = adapter_name
546
+ _adapters_has_been_set = True
547
+
548
+ if not _adapters_has_been_set:
549
+ raise ValueError(
550
+ "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
551
+ )
552
+
553
+ def disable_adapters(self) -> None:
554
+ r"""
555
+ Disable all adapters attached to the model and fallback to inference with the base model only.
556
+
557
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
558
+ [documentation](https://huggingface.co/docs/peft).
559
+ """
560
+ check_peft_version(min_version=MIN_PEFT_VERSION)
561
+
562
+ if not self._hf_peft_config_loaded:
563
+ raise ValueError("No adapter loaded. Please load an adapter first.")
564
+
565
+ from peft.tuners.tuners_utils import BaseTunerLayer
566
+
567
+ for _, module in self.named_modules():
568
+ if isinstance(module, BaseTunerLayer):
569
+ if hasattr(module, "enable_adapters"):
570
+ module.enable_adapters(enabled=False)
571
+ else:
572
+ # support for older PEFT versions
573
+ module.disable_adapters = True
574
+
575
+ def enable_adapters(self) -> None:
576
+ """
577
+ Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
578
+ adapters to enable.
579
+
580
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
581
+ [documentation](https://huggingface.co/docs/peft).
582
+ """
583
+ check_peft_version(min_version=MIN_PEFT_VERSION)
584
+
585
+ if not self._hf_peft_config_loaded:
586
+ raise ValueError("No adapter loaded. Please load an adapter first.")
587
+
588
+ from peft.tuners.tuners_utils import BaseTunerLayer
589
+
590
+ for _, module in self.named_modules():
591
+ if isinstance(module, BaseTunerLayer):
592
+ if hasattr(module, "enable_adapters"):
593
+ module.enable_adapters(enabled=True)
594
+ else:
595
+ # support for older PEFT versions
596
+ module.disable_adapters = False
597
+
598
+ def active_adapters(self) -> List[str]:
599
+ """
600
+ Gets the current list of active adapters of the model.
601
+
602
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
603
+ [documentation](https://huggingface.co/docs/peft).
604
+ """
605
+ check_peft_version(min_version=MIN_PEFT_VERSION)
606
+
607
+ if not is_peft_available():
608
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
609
+
610
+ if not self._hf_peft_config_loaded:
611
+ raise ValueError("No adapter loaded. Please load an adapter first.")
612
+
613
+ from peft.tuners.tuners_utils import BaseTunerLayer
614
+
615
+ for _, module in self.named_modules():
616
+ if isinstance(module, BaseTunerLayer):
617
+ return module.active_adapter
618
+
619
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
620
+ if not USE_PEFT_BACKEND:
621
+ raise ValueError("PEFT backend is required for `fuse_lora()`.")
622
+
623
+ self.lora_scale = lora_scale
624
+ self._safe_fusing = safe_fusing
625
+ self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
626
+
627
+ def _fuse_lora_apply(self, module, adapter_names=None):
628
+ from peft.tuners.tuners_utils import BaseTunerLayer
629
+
630
+ merge_kwargs = {"safe_merge": self._safe_fusing}
631
+
632
+ if isinstance(module, BaseTunerLayer):
633
+ if self.lora_scale != 1.0:
634
+ module.scale_layer(self.lora_scale)
635
+
636
+ # For BC with previous PEFT versions, we need to check the signature
637
+ # of the `merge` method to see if it supports the `adapter_names` argument.
638
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
639
+ if "adapter_names" in supported_merge_kwargs:
640
+ merge_kwargs["adapter_names"] = adapter_names
641
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
642
+ raise ValueError(
643
+ "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
644
+ " to the latest version of PEFT. `pip install -U peft`"
645
+ )
646
+
647
+ module.merge(**merge_kwargs)
648
+
649
+ def unfuse_lora(self):
650
+ if not USE_PEFT_BACKEND:
651
+ raise ValueError("PEFT backend is required for `unfuse_lora()`.")
652
+ self.apply(self._unfuse_lora_apply)
653
+
654
+ def _unfuse_lora_apply(self, module):
655
+ from peft.tuners.tuners_utils import BaseTunerLayer
656
+
657
+ if isinstance(module, BaseTunerLayer):
658
+ module.unmerge()
659
+
660
+ def unload_lora(self):
661
+ if not USE_PEFT_BACKEND:
662
+ raise ValueError("PEFT backend is required for `unload_lora()`.")
663
+
664
+ from ..utils import recurse_remove_peft_layers
665
+
666
+ recurse_remove_peft_layers(self)
667
+ if hasattr(self, "peft_config"):
668
+ del self.peft_config
669
+
670
+ def disable_lora(self):
671
+ """
672
+ Disables the active LoRA layers of the underlying model.
673
+
674
+ Example:
675
+
676
+ ```py
677
+ from diffusers import AutoPipelineForText2Image
678
+ import torch
679
+
680
+ pipeline = AutoPipelineForText2Image.from_pretrained(
681
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
682
+ ).to("cuda")
683
+ pipeline.load_lora_weights(
684
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
685
+ )
686
+ pipeline.disable_lora()
687
+ ```
688
+ """
689
+ if not USE_PEFT_BACKEND:
690
+ raise ValueError("PEFT backend is required for this method.")
691
+ set_adapter_layers(self, enabled=False)
692
+
693
+ def enable_lora(self):
694
+ """
695
+ Enables the active LoRA layers of the underlying model.
696
+
697
+ Example:
698
+
699
+ ```py
700
+ from diffusers import AutoPipelineForText2Image
701
+ import torch
702
+
703
+ pipeline = AutoPipelineForText2Image.from_pretrained(
704
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
705
+ ).to("cuda")
706
+ pipeline.load_lora_weights(
707
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
708
+ )
709
+ pipeline.enable_lora()
710
+ ```
711
+ """
712
+ if not USE_PEFT_BACKEND:
713
+ raise ValueError("PEFT backend is required for this method.")
714
+ set_adapter_layers(self, enabled=True)
715
+
716
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
717
+ """
718
+ Delete an adapter's LoRA layers from the underlying model.
719
+
720
+ Args:
721
+ adapter_names (`Union[List[str], str]`):
722
+ The names (single string or list of strings) of the adapter to delete.
723
+
724
+ Example:
725
+
726
+ ```py
727
+ from diffusers import AutoPipelineForText2Image
728
+ import torch
729
+
730
+ pipeline = AutoPipelineForText2Image.from_pretrained(
731
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
732
+ ).to("cuda")
733
+ pipeline.load_lora_weights(
734
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
735
+ )
736
+ pipeline.delete_adapters("cinematic")
737
+ ```
738
+ """
739
+ if not USE_PEFT_BACKEND:
740
+ raise ValueError("PEFT backend is required for this method.")
741
+
742
+ if isinstance(adapter_names, str):
743
+ adapter_names = [adapter_names]
744
+
745
+ for adapter_name in adapter_names:
746
+ delete_adapter_layers(self, adapter_name)
747
+
748
+ # Pop also the corresponding adapter from the config
749
+ if hasattr(self, "peft_config"):
750
+ self.peft_config.pop(adapter_name, None)
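To round off the mixin, here is how the blending and fusing helpers above might be combined, again continuing the hypothetical `transformer` and adapters from the earlier sketches.

```py
# Hypothetical continuation: load a second adapter, blend, then fuse/unfuse.
transformer.load_lora_adapter("some-org/flux-detail-lora", adapter_name="detail")  # placeholder repo
transformer.set_adapters(["style", "detail"], weights=[0.8, 0.4])

# fuse_lora() merges the active adapters into the base weights for faster inference;
# unfuse_lora() restores the original, unmerged weights.
transformer.fuse_lora(lora_scale=1.0)
transformer.unfuse_lora()
```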
icedit/diffusers/loaders/single_file.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import os
17
+
18
+ import torch
19
+ from huggingface_hub import snapshot_download
20
+ from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
21
+ from packaging import version
22
+
23
+ from ..utils import deprecate, is_transformers_available, logging
24
+ from .single_file_utils import (
25
+ SingleFileComponentError,
26
+ _is_legacy_scheduler_kwargs,
27
+ _is_model_weights_in_cached_folder,
28
+ _legacy_load_clip_tokenizer,
29
+ _legacy_load_safety_checker,
30
+ _legacy_load_scheduler,
31
+ create_diffusers_clip_model_from_ldm,
32
+ create_diffusers_t5_model_from_checkpoint,
33
+ fetch_diffusers_config,
34
+ fetch_original_config,
35
+ is_clip_model_in_single_file,
36
+ is_t5_in_single_file,
37
+ load_single_file_checkpoint,
38
+ )
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ # Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
44
+ SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
45
+
46
+ if is_transformers_available():
47
+ import transformers
48
+ from transformers import PreTrainedModel, PreTrainedTokenizer
49
+
50
+
51
+ def load_single_file_sub_model(
52
+ library_name,
53
+ class_name,
54
+ name,
55
+ checkpoint,
56
+ pipelines,
57
+ is_pipeline_module,
58
+ cached_model_config_path,
59
+ original_config=None,
60
+ local_files_only=False,
61
+ torch_dtype=None,
62
+ is_legacy_loading=False,
63
+ **kwargs,
64
+ ):
65
+ if is_pipeline_module:
66
+ pipeline_module = getattr(pipelines, library_name)
67
+ class_obj = getattr(pipeline_module, class_name)
68
+ else:
69
+ # else we just import it from the library.
70
+ library = importlib.import_module(library_name)
71
+ class_obj = getattr(library, class_name)
72
+
73
+ if is_transformers_available():
74
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
75
+ else:
76
+ transformers_version = "N/A"
77
+
78
+ is_transformers_model = (
79
+ is_transformers_available()
80
+ and issubclass(class_obj, PreTrainedModel)
81
+ and transformers_version >= version.parse("4.20.0")
82
+ )
83
+ is_tokenizer = (
84
+ is_transformers_available()
85
+ and issubclass(class_obj, PreTrainedTokenizer)
86
+ and transformers_version >= version.parse("4.20.0")
87
+ )
88
+
89
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
90
+ is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
91
+ is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
92
+ is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
93
+
94
+ if is_diffusers_single_file_model:
95
+ load_method = getattr(class_obj, "from_single_file")
96
+
97
+ # We cannot provide two different config options to the `from_single_file` method
98
+ # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
99
+ if original_config:
100
+ cached_model_config_path = None
101
+
102
+ loaded_sub_model = load_method(
103
+ pretrained_model_link_or_path_or_dict=checkpoint,
104
+ original_config=original_config,
105
+ config=cached_model_config_path,
106
+ subfolder=name,
107
+ torch_dtype=torch_dtype,
108
+ local_files_only=local_files_only,
109
+ **kwargs,
110
+ )
111
+
112
+ elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
113
+ loaded_sub_model = create_diffusers_clip_model_from_ldm(
114
+ class_obj,
115
+ checkpoint=checkpoint,
116
+ config=cached_model_config_path,
117
+ subfolder=name,
118
+ torch_dtype=torch_dtype,
119
+ local_files_only=local_files_only,
120
+ is_legacy_loading=is_legacy_loading,
121
+ )
122
+
123
+ elif is_transformers_model and is_t5_in_single_file(checkpoint):
124
+ loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
125
+ class_obj,
126
+ checkpoint=checkpoint,
127
+ config=cached_model_config_path,
128
+ subfolder=name,
129
+ torch_dtype=torch_dtype,
130
+ local_files_only=local_files_only,
131
+ )
132
+
133
+ elif is_tokenizer and is_legacy_loading:
134
+ loaded_sub_model = _legacy_load_clip_tokenizer(
135
+ class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
136
+ )
137
+
138
+ elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
139
+ loaded_sub_model = _legacy_load_scheduler(
140
+ class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
141
+ )
142
+
143
+ else:
144
+ if not hasattr(class_obj, "from_pretrained"):
145
+ raise ValueError(
146
+ (
147
+ f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
148
+ " a supported loading method."
149
+ )
150
+ )
151
+
152
+ loading_kwargs = {}
153
+ loading_kwargs.update(
154
+ {
155
+ "pretrained_model_name_or_path": cached_model_config_path,
156
+ "subfolder": name,
157
+ "local_files_only": local_files_only,
158
+ }
159
+ )
160
+
161
+ # Schedulers and Tokenizers don't make use of torch_dtype
162
+ # Skip passing it to those objects
163
+ if issubclass(class_obj, torch.nn.Module):
164
+ loading_kwargs.update({"torch_dtype": torch_dtype})
165
+
166
+ if is_diffusers_model or is_transformers_model:
167
+ if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
168
+ raise SingleFileComponentError(
169
+ f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
170
+ )
171
+
172
+ load_method = getattr(class_obj, "from_pretrained")
173
+ loaded_sub_model = load_method(**loading_kwargs)
174
+
175
+ return loaded_sub_model
176
+
177
+
178
+ def _map_component_types_to_config_dict(component_types):
179
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
180
+ config_dict = {}
181
+ component_types.pop("self", None)
182
+
183
+ if is_transformers_available():
184
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
185
+ else:
186
+ transformers_version = "N/A"
187
+
188
+ for component_name, component_value in component_types.items():
189
+ is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
190
+ is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
191
+ is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
192
+
193
+ is_transformers_model = (
194
+ is_transformers_available()
195
+ and issubclass(component_value[0], PreTrainedModel)
196
+ and transformers_version >= version.parse("4.20.0")
197
+ )
198
+ is_transformers_tokenizer = (
199
+ is_transformers_available()
200
+ and issubclass(component_value[0], PreTrainedTokenizer)
201
+ and transformers_version >= version.parse("4.20.0")
202
+ )
203
+
204
+ if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
205
+ config_dict[component_name] = ["diffusers", component_value[0].__name__]
206
+
207
+ elif is_scheduler_enum or is_scheduler:
208
+ if is_scheduler_enum:
209
+ # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
210
+ # if the type hint is a KarrassDiffusionSchedulers enum
211
+ config_dict[component_name] = ["diffusers", "DDIMScheduler"]
212
+
213
+ elif is_scheduler:
214
+ config_dict[component_name] = ["diffusers", component_value[0].__name__]
215
+
216
+ elif (
217
+ is_transformers_model or is_transformers_tokenizer
218
+ ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
219
+ config_dict[component_name] = ["transformers", component_value[0].__name__]
220
+
221
+ else:
222
+ config_dict[component_name] = [None, None]
223
+
224
+ return config_dict
225
+
226
+
227
+ def _infer_pipeline_config_dict(pipeline_class):
228
+ parameters = inspect.signature(pipeline_class.__init__).parameters
229
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
230
+ component_types = pipeline_class._get_signature_types()
231
+
232
+ # Ignore parameters that are not required for the pipeline
233
+ component_types = {k: v for k, v in component_types.items() if k in required_parameters}
234
+ config_dict = _map_component_types_to_config_dict(component_types)
235
+
236
+ return config_dict
237
+
238
+
239
+ def _download_diffusers_model_config_from_hub(
240
+ pretrained_model_name_or_path,
241
+ cache_dir,
242
+ revision,
243
+ proxies,
244
+ force_download=None,
245
+ local_files_only=None,
246
+ token=None,
247
+ ):
248
+ allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
249
+ cached_model_path = snapshot_download(
250
+ pretrained_model_name_or_path,
251
+ cache_dir=cache_dir,
252
+ revision=revision,
253
+ proxies=proxies,
254
+ force_download=force_download,
255
+ local_files_only=local_files_only,
256
+ token=token,
257
+ allow_patterns=allow_patterns,
258
+ )
259
+
260
+ return cached_model_path
261
+
262
+
263
+ class FromSingleFileMixin:
264
+ """
265
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
266
+ """
267
+
268
+ @classmethod
269
+ @validate_hf_hub_args
270
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
271
+ r"""
272
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
273
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
274
+
275
+ Parameters:
276
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
277
+ Can be either:
278
+ - A link to the `.ckpt` file (for example
279
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
280
+ - A path to a *file* containing all pipeline weights.
281
+ torch_dtype (`str` or `torch.dtype`, *optional*):
282
+ Override the default `torch.dtype` and load the model with another dtype.
283
+ force_download (`bool`, *optional*, defaults to `False`):
284
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
285
+ cached versions if they exist.
286
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
287
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
288
+ is not used.
289
+
290
+ proxies (`Dict[str, str]`, *optional*):
291
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
292
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
293
+ local_files_only (`bool`, *optional*, defaults to `False`):
294
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
295
+ won't be downloaded from the Hub.
296
+ token (`str` or *bool*, *optional*):
297
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
298
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
299
+ revision (`str`, *optional*, defaults to `"main"`):
300
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
301
+ allowed by Git.
302
+ original_config_file (`str`, *optional*):
303
+ The path to the original config file that was used to train the model. If not provided, the config file
304
+ will be inferred from the checkpoint file.
305
+ config (`str`, *optional*):
306
+ Can be either:
307
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
308
+ hosted on the Hub.
309
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
310
+ component configs in Diffusers format.
311
+ kwargs (remaining dictionary of keyword arguments, *optional*):
312
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
313
+ class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
314
+ below for more information.
315
+
316
+ Examples:
317
+
318
+ ```py
319
+ >>> from diffusers import StableDiffusionPipeline
320
+
321
+ >>> # Download pipeline from huggingface.co and cache.
322
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
323
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
324
+ ... )
325
+
326
+ >>> # Download pipeline from local file
327
+ >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
328
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
329
+
330
+ >>> # Enable float16 and move to GPU
331
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
332
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
333
+ ... torch_dtype=torch.float16,
334
+ ... )
335
+ >>> pipeline.to("cuda")
336
+ ```
337
+
338
+ """
339
+ original_config_file = kwargs.pop("original_config_file", None)
340
+ config = kwargs.pop("config", None)
341
+ original_config = kwargs.pop("original_config", None)
342
+
343
+ if original_config_file is not None:
344
+ deprecation_message = (
345
+ "`original_config_file` argument is deprecated and will be removed in future versions. "
346
+ "Please use the `original_config` argument instead."
347
+ )
348
+ deprecate("original_config_file", "1.0.0", deprecation_message)
349
+ original_config = original_config_file
350
+
351
+ force_download = kwargs.pop("force_download", False)
352
+ proxies = kwargs.pop("proxies", None)
353
+ token = kwargs.pop("token", None)
354
+ cache_dir = kwargs.pop("cache_dir", None)
355
+ local_files_only = kwargs.pop("local_files_only", False)
356
+ revision = kwargs.pop("revision", None)
357
+ torch_dtype = kwargs.pop("torch_dtype", None)
358
+
359
+ is_legacy_loading = False
360
+
361
+ # We shouldn't allow configuring individual model components through a Pipeline creation method
362
+ # These model kwargs should be deprecated
363
+ scaling_factor = kwargs.get("scaling_factor", None)
364
+ if scaling_factor is not None:
365
+ deprecation_message = (
366
+ "Passing the `scaling_factor` argument to `from_single_file` is deprecated "
367
+ "and will be ignored in future versions."
368
+ )
369
+ deprecate("scaling_factor", "1.0.0", deprecation_message)
370
+
371
+ if original_config is not None:
372
+ original_config = fetch_original_config(original_config, local_files_only=local_files_only)
373
+
374
+ from ..pipelines.pipeline_utils import _get_pipeline_class
375
+
376
+ pipeline_class = _get_pipeline_class(cls, config=None)
377
+
378
+ checkpoint = load_single_file_checkpoint(
379
+ pretrained_model_link_or_path,
380
+ force_download=force_download,
381
+ proxies=proxies,
382
+ token=token,
383
+ cache_dir=cache_dir,
384
+ local_files_only=local_files_only,
385
+ revision=revision,
386
+ )
387
+
388
+ if config is None:
389
+ config = fetch_diffusers_config(checkpoint)
390
+ default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
391
+ else:
392
+ default_pretrained_model_config_name = config
393
+
394
+ if not os.path.isdir(default_pretrained_model_config_name):
395
+ # Provided config is a repo_id
396
+ if default_pretrained_model_config_name.count("/") > 1:
397
+ raise ValueError(
398
+ f'The provided config "{config}"'
399
+ " is neither a valid local path nor a valid repo id. Please check the parameter."
400
+ )
401
+ try:
402
+ # Attempt to download the config files for the pipeline
403
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
404
+ default_pretrained_model_config_name,
405
+ cache_dir=cache_dir,
406
+ revision=revision,
407
+ proxies=proxies,
408
+ force_download=force_download,
409
+ local_files_only=local_files_only,
410
+ token=token,
411
+ )
412
+ config_dict = pipeline_class.load_config(cached_model_config_path)
413
+
414
+ except LocalEntryNotFoundError:
415
+ # `local_files_only=True` but a local diffusers format model config is not available in the cache
416
+ # If `original_config` is not provided, we need to override `local_files_only` to False
417
+ # to fetch the config files from the hub so that we have a way
418
+ # to configure the pipeline components.
419
+
420
+ if original_config is None:
421
+ logger.warning(
422
+ "`local_files_only` is True but no local configs were found for this checkpoint.\n"
423
+ "Attempting to download the necessary config files for this pipeline.\n"
424
+ )
425
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
426
+ default_pretrained_model_config_name,
427
+ cache_dir=cache_dir,
428
+ revision=revision,
429
+ proxies=proxies,
430
+ force_download=force_download,
431
+ local_files_only=False,
432
+ token=token,
433
+ )
434
+ config_dict = pipeline_class.load_config(cached_model_config_path)
435
+
436
+ else:
437
+ # For backwards compatibility
438
+ # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
439
+ logger.warning(
440
+ "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
441
+ "This may lead to errors if the model components are not correctly inferred. \n"
442
+ "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
443
+ "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>)` \n"
444
+ "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
445
+ "the necessary config files.\n"
446
+ )
447
+ is_legacy_loading = True
448
+ cached_model_config_path = None
449
+
450
+ config_dict = _infer_pipeline_config_dict(pipeline_class)
451
+ config_dict["_class_name"] = pipeline_class.__name__
452
+
453
+ else:
454
+ # Provided config is a path to a local directory; attempt to load it directly.
455
+ cached_model_config_path = default_pretrained_model_config_name
456
+ config_dict = pipeline_class.load_config(cached_model_config_path)
457
+
458
+ # pop out "_ignore_files" as it is only needed for download
459
+ config_dict.pop("_ignore_files", None)
460
+
461
+ expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
462
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
463
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
464
+
465
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
466
+ init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
467
+ init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
468
+
469
+ from diffusers import pipelines
470
+
471
+ # remove `null` components
472
+ def load_module(name, value):
473
+ if value[0] is None:
474
+ return False
475
+ if name in passed_class_obj and passed_class_obj[name] is None:
476
+ return False
477
+ if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
478
+ return False
479
+
480
+ return True
481
+
482
+ init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
483
+
484
+ for name, (library_name, class_name) in logging.tqdm(
485
+ sorted(init_dict.items()), desc="Loading pipeline components..."
486
+ ):
487
+ loaded_sub_model = None
488
+ is_pipeline_module = hasattr(pipelines, library_name)
489
+
490
+ if name in passed_class_obj:
491
+ loaded_sub_model = passed_class_obj[name]
492
+
493
+ else:
494
+ try:
495
+ loaded_sub_model = load_single_file_sub_model(
496
+ library_name=library_name,
497
+ class_name=class_name,
498
+ name=name,
499
+ checkpoint=checkpoint,
500
+ is_pipeline_module=is_pipeline_module,
501
+ cached_model_config_path=cached_model_config_path,
502
+ pipelines=pipelines,
503
+ torch_dtype=torch_dtype,
504
+ original_config=original_config,
505
+ local_files_only=local_files_only,
506
+ is_legacy_loading=is_legacy_loading,
507
+ **kwargs,
508
+ )
509
+ except SingleFileComponentError as e:
510
+ raise SingleFileComponentError(
511
+ (
512
+ f"{e.message}\n"
513
+ f"Please load the component before passing it in as an argument to `from_single_file`.\n"
514
+ f"\n"
515
+ f"{name} = {class_name}.from_pretrained('...')\n"
516
+ f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
517
+ f"\n"
518
+ )
519
+ )
520
+
521
+ init_kwargs[name] = loaded_sub_model
522
+
523
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
524
+ passed_modules = list(passed_class_obj.keys())
525
+ optional_modules = pipeline_class._optional_components
526
+
527
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
528
+ for module in missing_modules:
529
+ init_kwargs[module] = passed_class_obj.get(module, None)
530
+ elif len(missing_modules) > 0:
531
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
532
+ raise ValueError(
533
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
534
+ )
535
+
536
+ # deprecated kwargs
537
+ load_safety_checker = kwargs.pop("load_safety_checker", None)
538
+ if load_safety_checker is not None:
539
+ deprecation_message = (
540
+ "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor` "
541
+ "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
542
+ )
543
+ deprecate("load_safety_checker", "1.0.0", deprecation_message)
544
+
545
+ safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
546
+ init_kwargs.update(safety_checker_components)
547
+
548
+ pipe = pipeline_class(**init_kwargs)
549
+
550
+ return pipe
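For orientation (not part of the diff above): a minimal usage sketch of the `FromSingleFileMixin.from_single_file` entry point, exercising the `config` argument that pins the pipeline component configs to a Diffusers-format repo. The local checkpoint path is a placeholder.

```py
# Minimal usage sketch (assumption: a local SDXL-style .safetensors checkpoint exists at this
# placeholder path; the repo id only supplies the Diffusers-format component configs).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "./my_sdxl_checkpoint.safetensors",                 # placeholder checkpoint path
    config="stabilityai/stable-diffusion-xl-base-1.0",  # Diffusers repo used for component configs
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```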
icedit/diffusers/loaders/single_file_model.py ADDED
@@ -0,0 +1,385 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import re
17
+ from contextlib import nullcontext
18
+ from typing import Optional
19
+
20
+ import torch
21
+ from huggingface_hub.utils import validate_hf_hub_args
22
+
23
+ from ..quantizers import DiffusersAutoQuantizer
24
+ from ..utils import deprecate, is_accelerate_available, logging
25
+ from .single_file_utils import (
26
+ SingleFileComponentError,
27
+ convert_animatediff_checkpoint_to_diffusers,
28
+ convert_autoencoder_dc_checkpoint_to_diffusers,
29
+ convert_controlnet_checkpoint,
30
+ convert_flux_transformer_checkpoint_to_diffusers,
31
+ convert_hunyuan_video_transformer_to_diffusers,
32
+ convert_ldm_unet_checkpoint,
33
+ convert_ldm_vae_checkpoint,
34
+ convert_ltx_transformer_checkpoint_to_diffusers,
35
+ convert_ltx_vae_checkpoint_to_diffusers,
36
+ convert_mochi_transformer_checkpoint_to_diffusers,
37
+ convert_sd3_transformer_checkpoint_to_diffusers,
38
+ convert_stable_cascade_unet_single_file_to_diffusers,
39
+ create_controlnet_diffusers_config_from_ldm,
40
+ create_unet_diffusers_config_from_ldm,
41
+ create_vae_diffusers_config_from_ldm,
42
+ fetch_diffusers_config,
43
+ fetch_original_config,
44
+ load_single_file_checkpoint,
45
+ )
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ if is_accelerate_available():
52
+ from accelerate import init_empty_weights
53
+
54
+ from ..models.modeling_utils import load_model_dict_into_meta
55
+
56
+
57
+ SINGLE_FILE_LOADABLE_CLASSES = {
58
+ "StableCascadeUNet": {
59
+ "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
60
+ },
61
+ "UNet2DConditionModel": {
62
+ "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
63
+ "config_mapping_fn": create_unet_diffusers_config_from_ldm,
64
+ "default_subfolder": "unet",
65
+ "legacy_kwargs": {
66
+ "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
67
+ },
68
+ },
69
+ "AutoencoderKL": {
70
+ "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
71
+ "config_mapping_fn": create_vae_diffusers_config_from_ldm,
72
+ "default_subfolder": "vae",
73
+ },
74
+ "ControlNetModel": {
75
+ "checkpoint_mapping_fn": convert_controlnet_checkpoint,
76
+ "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
77
+ },
78
+ "SD3Transformer2DModel": {
79
+ "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
80
+ "default_subfolder": "transformer",
81
+ },
82
+ "MotionAdapter": {
83
+ "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
84
+ },
85
+ "SparseControlNetModel": {
86
+ "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
87
+ },
88
+ "FluxTransformer2DModel": {
89
+ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
90
+ "default_subfolder": "transformer",
91
+ },
92
+ "LTXVideoTransformer3DModel": {
93
+ "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
94
+ "default_subfolder": "transformer",
95
+ },
96
+ "AutoencoderKLLTXVideo": {
97
+ "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
98
+ "default_subfolder": "vae",
99
+ },
100
+ "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
101
+ "MochiTransformer3DModel": {
102
+ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
103
+ "default_subfolder": "transformer",
104
+ },
105
+ "HunyuanVideoTransformer3DModel": {
106
+ "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
107
+ "default_subfolder": "transformer",
108
+ },
109
+ }
110
+
111
+
112
+ def _get_single_file_loadable_mapping_class(cls):
113
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
114
+ for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
115
+ loadable_class = getattr(diffusers_module, loadable_class_str)
116
+
117
+ if issubclass(cls, loadable_class):
118
+ return loadable_class_str
119
+
120
+ return None
121
+
122
+
123
+ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
124
+ parameters = inspect.signature(mapping_fn).parameters
125
+
126
+ mapping_kwargs = {}
127
+ for parameter in parameters:
128
+ if parameter in kwargs:
129
+ mapping_kwargs[parameter] = kwargs[parameter]
130
+
131
+ return mapping_kwargs
132
+
133
+
134
+ class FromOriginalModelMixin:
135
+ """
136
+ Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
137
+ """
138
+
139
+ @classmethod
140
+ @validate_hf_hub_args
141
+ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
142
+ r"""
143
+ Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
144
+ is set in evaluation mode (`model.eval()`) by default.
145
+
146
+ Parameters:
147
+ pretrained_model_link_or_path_or_dict (`str`, *optional*):
148
+ Can be either:
149
+ - A link to the `.safetensors` or `.ckpt` file (for example
150
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
151
+ - A path to a local *file* containing the weights of the component model.
152
+ - A state dict containing the component model weights.
153
+ config (`str`, *optional*):
154
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
155
+ on the Hub.
156
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
157
+ configs in Diffusers format.
158
+ subfolder (`str`, *optional*, defaults to `""`):
159
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
160
+ original_config (`str`, *optional*):
161
+ Dict or path to a yaml file containing the configuration for the model in its original format.
162
+ If a dict is provided, it will be used to initialize the model configuration.
163
+ torch_dtype (`str` or `torch.dtype`, *optional*):
164
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
165
+ dtype is automatically derived from the model's weights.
166
+ force_download (`bool`, *optional*, defaults to `False`):
167
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
168
+ cached versions if they exist.
169
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
170
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
171
+ is not used.
172
+
173
+ proxies (`Dict[str, str]`, *optional*):
174
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
175
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
176
+ local_files_only (`bool`, *optional*, defaults to `False`):
177
+ Whether to only load local model weights and configuration files or not. If set to True, the model
178
+ won't be downloaded from the Hub.
179
+ token (`str` or *bool*, *optional*):
180
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
181
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
182
+ revision (`str`, *optional*, defaults to `"main"`):
183
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
184
+ allowed by Git.
185
+ kwargs (remaining dictionary of keyword arguments, *optional*):
186
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
187
+ specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
188
+ method. See example below for more information.
189
+
190
+ ```py
191
+ >>> from diffusers import StableCascadeUNet
192
+
193
+ >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
194
+ >>> model = StableCascadeUNet.from_single_file(ckpt_path)
195
+ ```
196
+ """
197
+
198
+ mapping_class_name = _get_single_file_loadable_mapping_class(cls)
199
+ # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
200
+ if mapping_class_name is None:
201
+ raise ValueError(
202
+ f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
203
+ )
204
+
205
+ pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
206
+ if pretrained_model_link_or_path is not None:
207
+ deprecation_message = (
208
+ "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
209
+ )
210
+ deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
211
+ pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
212
+
213
+ config = kwargs.pop("config", None)
214
+ original_config = kwargs.pop("original_config", None)
215
+
216
+ if config is not None and original_config is not None:
217
+ raise ValueError(
218
+ "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
219
+ )
220
+
221
+ force_download = kwargs.pop("force_download", False)
222
+ proxies = kwargs.pop("proxies", None)
223
+ token = kwargs.pop("token", None)
224
+ cache_dir = kwargs.pop("cache_dir", None)
225
+ local_files_only = kwargs.pop("local_files_only", None)
226
+ subfolder = kwargs.pop("subfolder", None)
227
+ revision = kwargs.pop("revision", None)
228
+ config_revision = kwargs.pop("config_revision", None)
229
+ torch_dtype = kwargs.pop("torch_dtype", None)
230
+ quantization_config = kwargs.pop("quantization_config", None)
231
+ device = kwargs.pop("device", None)
232
+
233
+ if isinstance(pretrained_model_link_or_path_or_dict, dict):
234
+ checkpoint = pretrained_model_link_or_path_or_dict
235
+ else:
236
+ checkpoint = load_single_file_checkpoint(
237
+ pretrained_model_link_or_path_or_dict,
238
+ force_download=force_download,
239
+ proxies=proxies,
240
+ token=token,
241
+ cache_dir=cache_dir,
242
+ local_files_only=local_files_only,
243
+ revision=revision,
244
+ )
245
+ if quantization_config is not None:
246
+ hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
247
+ hf_quantizer.validate_environment()
248
+
249
+ else:
250
+ hf_quantizer = None
251
+
252
+ mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
253
+
254
+ checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
255
+ if original_config is not None:
256
+ if "config_mapping_fn" in mapping_functions:
257
+ config_mapping_fn = mapping_functions["config_mapping_fn"]
258
+ else:
259
+ config_mapping_fn = None
260
+
261
+ if config_mapping_fn is None:
262
+ raise ValueError(
263
+ (
264
+ f"`original_config` has been provided for {mapping_class_name} but no mapping function "
265
+ "was found to convert the original config to a Diffusers config in "
266
+ "`diffusers.loaders.single_file_utils`"
267
+ )
268
+ )
269
+
270
+ if isinstance(original_config, str):
271
+ # If original_config is a URL or filepath fetch the original_config dict
272
+ original_config = fetch_original_config(original_config, local_files_only=local_files_only)
273
+
274
+ config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
275
+ diffusers_model_config = config_mapping_fn(
276
+ original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
277
+ )
278
+ else:
279
+ if config is not None:
280
+ if isinstance(config, str):
281
+ default_pretrained_model_config_name = config
282
+ else:
283
+ raise ValueError(
284
+ (
285
+ "Invalid `config` argument. Please provide a string representing a repo id "
286
+ "or path to a local Diffusers model repo."
287
+ )
288
+ )
289
+
290
+ else:
291
+ config = fetch_diffusers_config(checkpoint)
292
+ default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
293
+
294
+ if "default_subfolder" in mapping_functions:
295
+ subfolder = mapping_functions["default_subfolder"]
296
+
297
+ subfolder = subfolder or config.pop(
298
+ "subfolder", None
299
+ ) # some configs contain a subfolder key, e.g. StableCascadeUNet
300
+
301
+ diffusers_model_config = cls.load_config(
302
+ pretrained_model_name_or_path=default_pretrained_model_config_name,
303
+ subfolder=subfolder,
304
+ local_files_only=local_files_only,
305
+ token=token,
306
+ revision=config_revision,
307
+ )
308
+ expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
309
+
310
+ # Map legacy kwargs to new kwargs
311
+ if "legacy_kwargs" in mapping_functions:
312
+ legacy_kwargs = mapping_functions["legacy_kwargs"]
313
+ for legacy_key, new_key in legacy_kwargs.items():
314
+ if legacy_key in kwargs:
315
+ kwargs[new_key] = kwargs.pop(legacy_key)
316
+
317
+ model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
318
+ diffusers_model_config.update(model_kwargs)
319
+
320
+ checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
321
+ diffusers_format_checkpoint = checkpoint_mapping_fn(
322
+ config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
323
+ )
324
+ if not diffusers_format_checkpoint:
325
+ raise SingleFileComponentError(
326
+ f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
327
+ )
328
+
329
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
330
+ with ctx():
331
+ model = cls.from_config(diffusers_model_config)
332
+
333
+ # Check if `_keep_in_fp32_modules` is not None
334
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
335
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
336
+ )
337
+ if use_keep_in_fp32_modules:
338
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
339
+ if not isinstance(keep_in_fp32_modules, list):
340
+ keep_in_fp32_modules = [keep_in_fp32_modules]
341
+
342
+ else:
343
+ keep_in_fp32_modules = []
344
+
345
+ if hf_quantizer is not None:
346
+ hf_quantizer.preprocess_model(
347
+ model=model,
348
+ device_map=None,
349
+ state_dict=diffusers_format_checkpoint,
350
+ keep_in_fp32_modules=keep_in_fp32_modules,
351
+ )
352
+
353
+ if is_accelerate_available():
354
+ param_device = torch.device(device) if device else torch.device("cpu")
355
+ unexpected_keys = load_model_dict_into_meta(
356
+ model,
357
+ diffusers_format_checkpoint,
358
+ dtype=torch_dtype,
359
+ device=param_device,
360
+ hf_quantizer=hf_quantizer,
361
+ keep_in_fp32_modules=keep_in_fp32_modules,
362
+ )
363
+
364
+ else:
365
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
366
+
367
+ if model._keys_to_ignore_on_load_unexpected is not None:
368
+ for pat in model._keys_to_ignore_on_load_unexpected:
369
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
370
+
371
+ if len(unexpected_keys) > 0:
372
+ logger.warning(
373
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
374
+ )
375
+
376
+ if hf_quantizer is not None:
377
+ hf_quantizer.postprocess_model(model)
378
+ model.hf_quantizer = hf_quantizer
379
+
380
+ if torch_dtype is not None and hf_quantizer is None:
381
+ model.to(torch_dtype)
382
+
383
+ model.eval()
384
+
385
+ return model
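As a quick illustration of the model-level `FromOriginalModelMixin.from_single_file` path above (it targets single components registered in `SINGLE_FILE_LOADABLE_CLASSES` rather than whole pipelines), here is a hedged sketch; the checkpoint path is a placeholder.

```py
# Minimal usage sketch (assumption: a ControlNet checkpoint in the original single-file layout
# is available at this placeholder path; ControlNetModel is one of SINGLE_FILE_LOADABLE_CLASSES).
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_single_file(
    "./controlnet_canny.safetensors",  # placeholder checkpoint path
    torch_dtype=torch.float16,
)
print(controlnet.config.in_channels)  # the converted model exposes a regular Diffusers config
```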
icedit/diffusers/loaders/single_file_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
icedit/diffusers/loaders/textual_inversion.py ADDED
@@ -0,0 +1,580 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List, Optional, Union
15
+
16
+ import safetensors
17
+ import torch
18
+ from huggingface_hub.utils import validate_hf_hub_args
19
+ from torch import nn
20
+
21
+ from ..models.modeling_utils import load_state_dict
22
+ from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
23
+
24
+
25
+ if is_transformers_available():
26
+ from transformers import PreTrainedModel, PreTrainedTokenizer
27
+
28
+ if is_accelerate_available():
29
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ TEXT_INVERSION_NAME = "learned_embeds.bin"
34
+ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
35
+
36
+
37
+ @validate_hf_hub_args
38
+ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
39
+ cache_dir = kwargs.pop("cache_dir", None)
40
+ force_download = kwargs.pop("force_download", False)
41
+ proxies = kwargs.pop("proxies", None)
42
+ local_files_only = kwargs.pop("local_files_only", None)
43
+ token = kwargs.pop("token", None)
44
+ revision = kwargs.pop("revision", None)
45
+ subfolder = kwargs.pop("subfolder", None)
46
+ weight_name = kwargs.pop("weight_name", None)
47
+ use_safetensors = kwargs.pop("use_safetensors", None)
48
+
49
+ allow_pickle = False
50
+ if use_safetensors is None:
51
+ use_safetensors = True
52
+ allow_pickle = True
53
+
54
+ user_agent = {
55
+ "file_type": "text_inversion",
56
+ "framework": "pytorch",
57
+ }
58
+ state_dicts = []
59
+ for pretrained_model_name_or_path in pretrained_model_name_or_paths:
60
+ if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
61
+ # 3.1. Load textual inversion file
62
+ model_file = None
63
+
64
+ # Let's first try to load .safetensors weights
65
+ if (use_safetensors and weight_name is None) or (
66
+ weight_name is not None and weight_name.endswith(".safetensors")
67
+ ):
68
+ try:
69
+ model_file = _get_model_file(
70
+ pretrained_model_name_or_path,
71
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
72
+ cache_dir=cache_dir,
73
+ force_download=force_download,
74
+ proxies=proxies,
75
+ local_files_only=local_files_only,
76
+ token=token,
77
+ revision=revision,
78
+ subfolder=subfolder,
79
+ user_agent=user_agent,
80
+ )
81
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
82
+ except Exception as e:
83
+ if not allow_pickle:
84
+ raise e
85
+
86
+ model_file = None
87
+
88
+ if model_file is None:
89
+ model_file = _get_model_file(
90
+ pretrained_model_name_or_path,
91
+ weights_name=weight_name or TEXT_INVERSION_NAME,
92
+ cache_dir=cache_dir,
93
+ force_download=force_download,
94
+ proxies=proxies,
95
+ local_files_only=local_files_only,
96
+ token=token,
97
+ revision=revision,
98
+ subfolder=subfolder,
99
+ user_agent=user_agent,
100
+ )
101
+ state_dict = load_state_dict(model_file)
102
+ else:
103
+ state_dict = pretrained_model_name_or_path
104
+
105
+ state_dicts.append(state_dict)
106
+
107
+ return state_dicts
108
+
109
+
110
+ class TextualInversionLoaderMixin:
111
+ r"""
112
+ Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
113
+ """
114
+
115
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
116
+ r"""
117
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
118
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
119
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
120
+
121
+ Parameters:
122
+ prompt (`str` or list of `str`):
123
+ The prompt or prompts to guide the image generation.
124
+ tokenizer (`PreTrainedTokenizer`):
125
+ The tokenizer responsible for encoding the prompt into input tokens.
126
+
127
+ Returns:
128
+ `str` or list of `str`: The converted prompt
129
+ """
130
+ if not isinstance(prompt, List):
131
+ prompts = [prompt]
132
+ else:
133
+ prompts = prompt
134
+
135
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
136
+
137
+ if not isinstance(prompt, List):
138
+ return prompts[0]
139
+
140
+ return prompts
141
+
142
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
143
+ r"""
144
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
145
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
146
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
147
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
148
+
149
+ Parameters:
150
+ prompt (`str`):
151
+ The prompt to guide the image generation.
152
+ tokenizer (`PreTrainedTokenizer`):
153
+ The tokenizer responsible for encoding the prompt into input tokens.
154
+
155
+ Returns:
156
+ `str`: The converted prompt
157
+ """
158
+ tokens = tokenizer.tokenize(prompt)
159
+ unique_tokens = set(tokens)
160
+ for token in unique_tokens:
161
+ if token in tokenizer.added_tokens_encoder:
162
+ replacement = token
163
+ i = 1
164
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
165
+ replacement += f" {token}_{i}"
166
+ i += 1
167
+
168
+ prompt = prompt.replace(token, replacement)
169
+
170
+ return prompt
171
+
172
+ def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
173
+ if tokenizer is None:
174
+ raise ValueError(
175
+ f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
176
+ f" `{self.load_textual_inversion.__name__}`"
177
+ )
178
+
179
+ if text_encoder is None:
180
+ raise ValueError(
181
+ f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
182
+ f" `{self.load_textual_inversion.__name__}`"
183
+ )
184
+
185
+ if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
186
+ raise ValueError(
187
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)} and a list of tokens of length {len(tokens)}. "
188
+ f"Make sure both lists have the same length."
189
+ )
190
+
191
+ valid_tokens = [t for t in tokens if t is not None]
192
+ if len(set(valid_tokens)) < len(valid_tokens):
193
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
194
+
195
+ @staticmethod
196
+ def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
197
+ all_tokens = []
198
+ all_embeddings = []
199
+ for state_dict, token in zip(state_dicts, tokens):
200
+ if isinstance(state_dict, torch.Tensor):
201
+ if token is None:
202
+ raise ValueError(
203
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
204
+ )
205
+ loaded_token = token
206
+ embedding = state_dict
207
+ elif len(state_dict) == 1:
208
+ # diffusers
209
+ loaded_token, embedding = next(iter(state_dict.items()))
210
+ elif "string_to_param" in state_dict:
211
+ # A1111
212
+ loaded_token = state_dict["name"]
213
+ embedding = state_dict["string_to_param"]["*"]
214
+ else:
215
+ raise ValueError(
216
+ f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
217
+ "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
218
+ " input key."
219
+ )
220
+
221
+ if token is not None and loaded_token != token:
222
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
223
+ else:
224
+ token = loaded_token
225
+
226
+ if token in tokenizer.get_vocab():
227
+ raise ValueError(
228
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
229
+ )
230
+
231
+ all_tokens.append(token)
232
+ all_embeddings.append(embedding)
233
+
234
+ return all_tokens, all_embeddings
235
+
236
+ @staticmethod
237
+ def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
238
+ all_tokens = []
239
+ all_embeddings = []
240
+
241
+ for embedding, token in zip(embeddings, tokens):
242
+ if f"{token}_1" in tokenizer.get_vocab():
243
+ multi_vector_tokens = [token]
244
+ i = 1
245
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
246
+ multi_vector_tokens.append(f"{token}_{i}")
247
+ i += 1
248
+
249
+ raise ValueError(
250
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
251
+ )
252
+
253
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
254
+ if is_multi_vector:
255
+ all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
256
+ all_embeddings += [e for e in embedding] # noqa: C416
257
+ else:
258
+ all_tokens += [token]
259
+ all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
260
+
261
+ return all_tokens, all_embeddings
262
+
263
+ @validate_hf_hub_args
264
+ def load_textual_inversion(
265
+ self,
266
+ pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
267
+ token: Optional[Union[str, List[str]]] = None,
268
+ tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
269
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
270
+ **kwargs,
271
+ ):
272
+ r"""
273
+ Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
274
+ Automatic1111 formats are supported).
275
+
276
+ Parameters:
277
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
278
+ Can be either one of the following or a list of them:
279
+
280
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
281
+ pretrained model hosted on the Hub.
282
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
283
+ inversion weights.
284
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
285
+ - A [torch state
286
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
287
+
288
+ token (`str` or `List[str]`, *optional*):
289
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
290
+ list, then `token` must also be a list of equal length.
291
+ text_encoder ([`~transformers.CLIPTextModel`], *optional*):
292
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
293
+ If not specified, the function uses `self.text_encoder`.
294
+ tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
295
+ A `CLIPTokenizer` to tokenize text. If not specified, the function uses `self.tokenizer`.
296
+ weight_name (`str`, *optional*):
297
+ Name of a custom weight file. This should be used when:
298
+
299
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
300
+ name such as `text_inv.bin`.
301
+ - The saved textual inversion file is in the Automatic1111 format.
302
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
303
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
304
+ is not used.
305
+ force_download (`bool`, *optional*, defaults to `False`):
306
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
307
+ cached versions if they exist.
308
+
309
+ proxies (`Dict[str, str]`, *optional*):
310
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
311
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
312
+ local_files_only (`bool`, *optional*, defaults to `False`):
313
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
314
+ won't be downloaded from the Hub.
315
+ token (`str` or *bool*, *optional*):
316
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
317
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
318
+ revision (`str`, *optional*, defaults to `"main"`):
319
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
320
+ allowed by Git.
321
+ subfolder (`str`, *optional*, defaults to `""`):
322
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
323
+ mirror (`str`, *optional*):
324
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
325
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
326
+ information.
327
+
328
+ Example:
329
+
330
+ To load a Textual Inversion embedding vector in 🤗 Diffusers format:
331
+
332
+ ```py
333
+ from diffusers import StableDiffusionPipeline
334
+ import torch
335
+
336
+ model_id = "runwayml/stable-diffusion-v1-5"
337
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
338
+
339
+ pipe.load_textual_inversion("sd-concepts-library/cat-toy")
340
+
341
+ prompt = "A <cat-toy> backpack"
342
+
343
+ image = pipe(prompt, num_inference_steps=50).images[0]
344
+ image.save("cat-backpack.png")
345
+ ```
346
+
347
+ To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first
348
+ (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
349
+ locally:
350
+
351
+ ```py
352
+ from diffusers import StableDiffusionPipeline
353
+ import torch
354
+
355
+ model_id = "runwayml/stable-diffusion-v1-5"
356
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
357
+
358
+ pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
359
+
360
+ prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
361
+
362
+ image = pipe(prompt, num_inference_steps=50).images[0]
363
+ image.save("character.png")
364
+ ```
365
+
366
+ """
367
+ # 1. Set correct tokenizer and text encoder
368
+ tokenizer = tokenizer or getattr(self, "tokenizer", None)
369
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
370
+
371
+ # 2. Normalize inputs
372
+ pretrained_model_name_or_paths = (
373
+ [pretrained_model_name_or_path]
374
+ if not isinstance(pretrained_model_name_or_path, list)
375
+ else pretrained_model_name_or_path
376
+ )
377
+ tokens = [token] if not isinstance(token, list) else token
378
+ if tokens[0] is None:
379
+ tokens = tokens * len(pretrained_model_name_or_paths)
380
+
381
+ # 3. Check inputs
382
+ self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
383
+
384
+ # 4. Load state dicts of textual embeddings
385
+ state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
386
+
387
+ # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
388
+ if len(tokens) > 1 and len(state_dicts) == 1:
389
+ if isinstance(state_dicts[0], torch.Tensor):
390
+ state_dicts = list(state_dicts[0])
391
+ if len(tokens) != len(state_dicts):
392
+ raise ValueError(
393
+ f"You have passed a state_dict that contains {len(state_dicts)} embeddings and a list of tokens of length {len(tokens)}. "
394
+ f"Make sure both have the same length."
395
+ )
396
+
397
+ # 4.2 Retrieve tokens and embeddings
398
+ tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
399
+
400
+ # 5. Extend tokens and embeddings for multi vector
401
+ tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
402
+
403
+ # 6. Make sure all embeddings have the correct size
404
+ expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
405
+ if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
406
+ raise ValueError(
407
+ "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
408
+ f"to be of shape {expected_emb_dim}, but found embeddings of shape {[emb.shape[-1] for emb in embeddings]}."
409
+ )
410
+
411
+ # 7. Now we can be sure that loading the embedding matrix works
412
+ # < Unsafe code:
413
+
414
+ # 7.1 Offload all hooks in case the pipeline was cpu-offloaded before; make sure we offload and onload again
415
+ is_model_cpu_offload = False
416
+ is_sequential_cpu_offload = False
417
+ if self.hf_device_map is None:
418
+ for _, component in self.components.items():
419
+ if isinstance(component, nn.Module):
420
+ if hasattr(component, "_hf_hook"):
421
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
422
+ is_sequential_cpu_offload = (
423
+ isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
424
+ or hasattr(component._hf_hook, "hooks")
425
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
426
+ )
427
+ logger.info(
428
+ "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
429
+ )
430
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
431
+
432
+ # 7.2 save expected device and dtype
433
+ device = text_encoder.device
434
+ dtype = text_encoder.dtype
435
+
436
+ # 7.3 Increase token embedding matrix
437
+ text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
438
+ input_embeddings = text_encoder.get_input_embeddings().weight
439
+
440
+ # 7.4 Load token and embedding
441
+ for token, embedding in zip(tokens, embeddings):
442
+ # add tokens and get ids
443
+ tokenizer.add_tokens(token)
444
+ token_id = tokenizer.convert_tokens_to_ids(token)
445
+ input_embeddings.data[token_id] = embedding
446
+ logger.info(f"Loaded textual inversion embedding for {token}.")
447
+
448
+ input_embeddings.to(dtype=dtype, device=device)
449
+
450
+ # 7.5 Offload the model again
451
+ if is_model_cpu_offload:
452
+ self.enable_model_cpu_offload()
453
+ elif is_sequential_cpu_offload:
454
+ self.enable_sequential_cpu_offload()
455
+
456
+ # / Unsafe Code >
457
+
458
+ def unload_textual_inversion(
459
+ self,
460
+ tokens: Optional[Union[str, List[str]]] = None,
461
+ tokenizer: Optional["PreTrainedTokenizer"] = None,
462
+ text_encoder: Optional["PreTrainedModel"] = None,
463
+ ):
464
+ r"""
465
+ Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
466
+
467
+ Example:
468
+ ```py
469
+ from diffusers import AutoPipelineForText2Image
470
+ import torch
471
+
472
+ pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
473
+
474
+ # Example 1
475
+ pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
476
+ pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
477
+
478
+ # Remove all token embeddings
479
+ pipeline.unload_textual_inversion()
480
+
481
+ # Example 2
482
+ pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
483
+ pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
484
+
485
+ # Remove just one token
486
+ pipeline.unload_textual_inversion("<moe-bius>")
487
+
488
+ # Example 3: unload from SDXL
489
+ pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
490
+ embedding_path = hf_hub_download(
491
+ repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
492
+ )
493
+
494
+ # load embeddings to the text encoders
495
+ state_dict = load_file(embedding_path)
496
+
497
+ # load embeddings of text_encoder 1 (CLIP ViT-L/14)
498
+ pipeline.load_textual_inversion(
499
+ state_dict["clip_l"],
500
+ tokens=["<s0>", "<s1>"],
501
+ text_encoder=pipeline.text_encoder,
502
+ tokenizer=pipeline.tokenizer,
503
+ )
504
+ # load embeddings of text_encoder 2 (CLIP ViT-G/14)
505
+ pipeline.load_textual_inversion(
506
+ state_dict["clip_g"],
507
+ tokens=["<s0>", "<s1>"],
508
+ text_encoder=pipeline.text_encoder_2,
509
+ tokenizer=pipeline.tokenizer_2,
510
+ )
511
+
512
+ # Unload explicitly from both text encoders and tokenizers
513
+ pipeline.unload_textual_inversion(
514
+ tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
515
+ )
516
+ pipeline.unload_textual_inversion(
517
+ tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
518
+ )
519
+ ```
520
+ """
521
+
522
+ tokenizer = tokenizer or getattr(self, "tokenizer", None)
523
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
524
+
525
+ # Get textual inversion tokens and ids
526
+ token_ids = []
527
+ last_special_token_id = None
528
+
529
+ if tokens:
530
+ if isinstance(tokens, str):
531
+ tokens = [tokens]
532
+ for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
533
+ if not added_token.special:
534
+ if added_token.content in tokens:
535
+ token_ids.append(added_token_id)
536
+ else:
537
+ last_special_token_id = added_token_id
538
+ if len(token_ids) == 0:
539
+ raise ValueError("No tokens to remove found")
540
+ else:
541
+ tokens = []
542
+ for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
543
+ if not added_token.special:
544
+ token_ids.append(added_token_id)
545
+ tokens.append(added_token.content)
546
+ else:
547
+ last_special_token_id = added_token_id
548
+
549
+ # Delete from tokenizer
550
+ for token_id, token_to_remove in zip(token_ids, tokens):
551
+ del tokenizer._added_tokens_decoder[token_id]
552
+ del tokenizer._added_tokens_encoder[token_to_remove]
553
+
554
+ # Make all token ids sequential in tokenizer
555
+ key_id = 1
556
+ for token_id in tokenizer.added_tokens_decoder:
557
+ if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
558
+ token = tokenizer._added_tokens_decoder[token_id]
559
+ tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
560
+ del tokenizer._added_tokens_decoder[token_id]
561
+ tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
562
+ key_id += 1
563
+ tokenizer._update_trie()
564
+ # set correct total vocab size after removing tokens
565
+ tokenizer._update_total_vocab_size()
566
+
567
+ # Delete from text encoder
568
+ text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
569
+ temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
570
+ text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
571
+ to_append = []
572
+ for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
573
+ if i not in token_ids:
574
+ to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
575
+ if len(to_append) > 0:
576
+ to_append = torch.cat(to_append, dim=0)
577
+ text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
578
+ text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
579
+ text_embeddings_filtered.weight.data = text_embedding_weights
580
+ text_encoder.set_input_embeddings(text_embeddings_filtered)
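To round off the `TextualInversionLoaderMixin` above, a short sketch of the load / prompt-conversion / unload cycle, using the same concept repo that the docstring examples already reference.

```py
# Minimal usage sketch: load a textual inversion concept, expand any multi-vector token in a
# prompt via maybe_convert_prompt (the pipeline also does this internally), then unload it.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_textual_inversion("sd-concepts-library/cat-toy")
prompt = pipe.maybe_convert_prompt("A <cat-toy> backpack", pipe.tokenizer)
image = pipe(prompt, num_inference_steps=30).images[0]

pipe.unload_textual_inversion("<cat-toy>")
```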
icedit/diffusers/loaders/transformer_flux.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from contextlib import nullcontext
15
+
16
+ from ..models.embeddings import (
17
+ ImageProjection,
18
+ MultiIPAdapterImageProjection,
19
+ )
20
+ from ..models.modeling_utils import load_model_dict_into_meta
21
+ from ..utils import (
22
+ is_accelerate_available,
23
+ is_torch_version,
24
+ logging,
25
+ )
26
+
27
+
28
+ if is_accelerate_available():
29
+ pass
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class FluxTransformer2DLoadersMixin:
35
+ """
36
+ Load layers into a [`FluxTransformer2DModel`].
37
+ """
38
+
39
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
40
+ if low_cpu_mem_usage:
41
+ if is_accelerate_available():
42
+ from accelerate import init_empty_weights
43
+
44
+ else:
45
+ low_cpu_mem_usage = False
46
+ logger.warning(
47
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
48
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
49
+ " `accelerate` for faster and less memory-intensive model loading. You can do so with: \n```\npip"
50
+ " install accelerate\n```\n."
51
+ )
52
+
53
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
54
+ raise NotImplementedError(
55
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
56
+ " `low_cpu_mem_usage=False`."
57
+ )
58
+
59
+ updated_state_dict = {}
60
+ image_projection = None
61
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
62
+
63
+ if "proj.weight" in state_dict:
64
+ # IP-Adapter
65
+ num_image_text_embeds = 4
66
+ if state_dict["proj.weight"].shape[0] == 65536:
67
+ num_image_text_embeds = 16
68
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
69
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
70
+
71
+ with init_context():
72
+ image_projection = ImageProjection(
73
+ cross_attention_dim=cross_attention_dim,
74
+ image_embed_dim=clip_embeddings_dim,
75
+ num_image_text_embeds=num_image_text_embeds,
76
+ )
77
+
78
+ for key, value in state_dict.items():
79
+ diffusers_name = key.replace("proj", "image_embeds")
80
+ updated_state_dict[diffusers_name] = value
81
+
82
+ if not low_cpu_mem_usage:
83
+ image_projection.load_state_dict(updated_state_dict, strict=True)
84
+ else:
85
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
86
+
87
+ return image_projection
88
+
89
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
90
+ from ..models.attention_processor import (
91
+ FluxIPAdapterJointAttnProcessor2_0,
92
+ )
93
+
94
+ if low_cpu_mem_usage:
95
+ if is_accelerate_available():
96
+ from accelerate import init_empty_weights
97
+
98
+ else:
99
+ low_cpu_mem_usage = False
100
+ logger.warning(
101
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
102
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
103
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
104
+ " install accelerate\n```\n."
105
+ )
106
+
107
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
108
+ raise NotImplementedError(
109
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
110
+ " `low_cpu_mem_usage=False`."
111
+ )
112
+
113
+ # set ip-adapter cross-attention processors & load state_dict
114
+ attn_procs = {}
115
+ key_id = 0
116
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
117
+ for name in self.attn_processors.keys():
118
+ if name.startswith("single_transformer_blocks"):
119
+ attn_processor_class = self.attn_processors[name].__class__
120
+ attn_procs[name] = attn_processor_class()
121
+ else:
122
+ cross_attention_dim = self.config.joint_attention_dim
123
+ hidden_size = self.inner_dim
124
+ attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
125
+ num_image_text_embeds = []
126
+ for state_dict in state_dicts:
127
+ if "proj.weight" in state_dict["image_proj"]:
128
+ num_image_text_embed = 4
129
+ if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
130
+ num_image_text_embed = 16
131
+ # IP-Adapter
132
+ num_image_text_embeds += [num_image_text_embed]
133
+
134
+ with init_context():
135
+ attn_procs[name] = attn_processor_class(
136
+ hidden_size=hidden_size,
137
+ cross_attention_dim=cross_attention_dim,
138
+ scale=1.0,
139
+ num_tokens=num_image_text_embeds,
140
+ dtype=self.dtype,
141
+ device=self.device,
142
+ )
143
+
144
+ value_dict = {}
145
+ for i, state_dict in enumerate(state_dicts):
146
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
147
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
148
+ value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
149
+ value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
150
+
151
+ if not low_cpu_mem_usage:
152
+ attn_procs[name].load_state_dict(value_dict)
153
+ else:
154
+ device = self.device
155
+ dtype = self.dtype
156
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
157
+
158
+ key_id += 1
159
+
160
+ return attn_procs
161
+
162
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
163
+ if not isinstance(state_dicts, list):
164
+ state_dicts = [state_dicts]
165
+
166
+ self.encoder_hid_proj = None
167
+
168
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
169
+ self.set_attn_processor(attn_procs)
170
+
171
+ image_projection_layers = []
172
+ for state_dict in state_dicts:
173
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
174
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
175
+ )
176
+ image_projection_layers.append(image_projection_layer)
177
+
178
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
179
+ self.config.encoder_hid_dim_type = "ip_image_proj"
180
+
181
+ self.to(dtype=self.dtype, device=self.device)
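For reference, the shape arithmetic in `_convert_ip_adapter_image_proj_to_diffusers` above (4 vs. 16 image tokens, and the cross-attention width recovered from `proj.weight`) can be exercised on its own. A small sketch with a fabricated state dict; the 65536 row count corresponds to 16 tokens at a 4096-dim joint-attention width, and the 1152-dim image embedding is only an illustrative placeholder:

```py
import torch

# Fabricated IP-Adapter image-projection weights (not from a real checkpoint).
state_dict = {"proj.weight": torch.zeros(65536, 1152), "proj.bias": torch.zeros(65536)}

num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
    num_image_text_embeds = 16

clip_embeddings_dim = state_dict["proj.weight"].shape[-1]                           # 1152
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds   # 4096

print(num_image_text_embeds, clip_embeddings_dim, cross_attention_dim)  # 16 1152 4096
```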
icedit/diffusers/loaders/transformer_sd3.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict
15
+
16
+ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
17
+ from ..models.embeddings import IPAdapterTimeImageProjection
18
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
19
+
20
+
21
+ class SD3Transformer2DLoadersMixin:
22
+ """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
23
+
24
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
25
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
26
+
27
+ Args:
28
+ state_dict (`Dict`):
29
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
30
+ "image_proj", which contains parameters for image projection net.
31
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
32
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
33
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
34
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
35
+ argument to `True` will raise an error.
36
+ """
37
+ # IP-Adapter cross attention parameters
38
+ hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
39
+ ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
40
+ timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
41
+
42
+ # Dict where key is transformer layer index, value is attention processor's state dict
43
+ # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
44
+ layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
45
+ for key, weights in state_dict["ip_adapter"].items():
46
+ idx, name = key.split(".", maxsplit=1)
47
+ layer_state_dict[int(idx)][name] = weights
48
+
49
+ # Create IP-Adapter attention processor
50
+ attn_procs = {}
51
+ for idx, name in enumerate(self.attn_processors.keys()):
52
+ attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
53
+ hidden_size=hidden_size,
54
+ ip_hidden_states_dim=ip_hidden_states_dim,
55
+ head_dim=self.config.attention_head_dim,
56
+ timesteps_emb_dim=timesteps_emb_dim,
57
+ ).to(self.device, dtype=self.dtype)
58
+
59
+ if not low_cpu_mem_usage:
60
+ attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
61
+ else:
62
+ load_model_dict_into_meta(
63
+ attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
64
+ )
65
+
66
+ self.set_attn_processor(attn_procs)
67
+
68
+ # Image projection parameters
69
+ embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
70
+ output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
71
+ hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
72
+ heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
73
+ num_queries = state_dict["image_proj"]["latents"].shape[1]
74
+ timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
75
+
76
+ # Image projection
77
+ self.image_proj = IPAdapterTimeImageProjection(
78
+ embed_dim=embed_dim,
79
+ output_dim=output_dim,
80
+ hidden_dim=hidden_dim,
81
+ heads=heads,
82
+ num_queries=num_queries,
83
+ timestep_in_dim=timestep_in_dim,
84
+ ).to(device=self.device, dtype=self.dtype)
85
+
86
+ if not low_cpu_mem_usage:
87
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
88
+ else:
89
+ load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
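One step in the SD3 loader above is regrouping the flat `"<layer idx>.<param name>"` keys of the `ip_adapter` state dict into one sub-dict per attention processor. A small sketch of that grouping with dummy tensors (key names follow the `"0.norm_ip.linear.weight"` pattern from the comment; shapes and the layer count are placeholders):

```py
import torch

ip_adapter_sd = {
    "0.norm_ip.linear.weight": torch.zeros(4, 8),
    "0.to_k_ip.weight": torch.zeros(4, 4),
    "1.norm_ip.linear.weight": torch.zeros(4, 8),
    "1.to_k_ip.weight": torch.zeros(4, 4),
}

num_layers = 2  # stands in for len(self.attn_processors)
layer_state_dict = {idx: {} for idx in range(num_layers)}
for key, weights in ip_adapter_sd.items():
    idx, name = key.split(".", maxsplit=1)
    layer_state_dict[int(idx)][name] = weights

print(sorted(layer_state_dict[0]))  # ['norm_ip.linear.weight', 'to_k_ip.weight']
```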
icedit/diffusers/loaders/unet.py ADDED
@@ -0,0 +1,927 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from collections import defaultdict
16
+ from contextlib import nullcontext
17
+ from pathlib import Path
18
+ from typing import Callable, Dict, Union
19
+
20
+ import safetensors
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from huggingface_hub.utils import validate_hf_hub_args
24
+
25
+ from ..models.embeddings import (
26
+ ImageProjection,
27
+ IPAdapterFaceIDImageProjection,
28
+ IPAdapterFaceIDPlusImageProjection,
29
+ IPAdapterFullImageProjection,
30
+ IPAdapterPlusImageProjection,
31
+ MultiIPAdapterImageProjection,
32
+ )
33
+ from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
34
+ from ..utils import (
35
+ USE_PEFT_BACKEND,
36
+ _get_model_file,
37
+ convert_unet_state_dict_to_peft,
38
+ deprecate,
39
+ get_adapter_name,
40
+ get_peft_kwargs,
41
+ is_accelerate_available,
42
+ is_peft_version,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+ from .lora_base import _func_optionally_disable_offloading
47
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
48
+ from .utils import AttnProcsLayers
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
55
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
56
+
57
+
58
+ class UNet2DConditionLoadersMixin:
59
+ """
60
+ Load LoRA layers into a [`UNet2DConditionModel`].
61
+ """
62
+
63
+ text_encoder_name = TEXT_ENCODER_NAME
64
+ unet_name = UNET_NAME
65
+
66
+ @validate_hf_hub_args
67
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
68
+ r"""
69
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
70
+ defined in
71
+ [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
72
+ and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
73
+ `peft`: `pip install -U peft`.
74
+
75
+ Parameters:
76
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
77
+ Can be either:
78
+
79
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
80
+ the Hub.
81
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
82
+ with [`ModelMixin.save_pretrained`].
83
+ - A [torch state
84
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
85
+
86
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
87
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
88
+ is not used.
89
+ force_download (`bool`, *optional*, defaults to `False`):
90
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
91
+ cached versions if they exist.
92
+
93
+ proxies (`Dict[str, str]`, *optional*):
94
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
95
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
96
+ local_files_only (`bool`, *optional*, defaults to `False`):
97
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
98
+ won't be downloaded from the Hub.
99
+ token (`str` or *bool*, *optional*):
100
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
101
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
102
+ revision (`str`, *optional*, defaults to `"main"`):
103
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
104
+ allowed by Git.
105
+ subfolder (`str`, *optional*, defaults to `""`):
106
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
107
+ network_alphas (`Dict[str, float]`):
108
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
109
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
110
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
111
+ adapter_name (`str`, *optional*, defaults to None):
112
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
113
+ `default_{i}` where i is the total number of adapters being loaded.
114
+ weight_name (`str`, *optional*, defaults to None):
115
+ Name of the serialized state dict file.
116
+ low_cpu_mem_usage (`bool`, *optional*):
117
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
118
+ weights.
119
+
120
+ Example:
121
+
122
+ ```py
123
+ from diffusers import AutoPipelineForText2Image
124
+ import torch
125
+
126
+ pipeline = AutoPipelineForText2Image.from_pretrained(
127
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
128
+ ).to("cuda")
129
+ pipeline.unet.load_attn_procs(
130
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
131
+ )
132
+ ```
133
+ """
134
+ cache_dir = kwargs.pop("cache_dir", None)
135
+ force_download = kwargs.pop("force_download", False)
136
+ proxies = kwargs.pop("proxies", None)
137
+ local_files_only = kwargs.pop("local_files_only", None)
138
+ token = kwargs.pop("token", None)
139
+ revision = kwargs.pop("revision", None)
140
+ subfolder = kwargs.pop("subfolder", None)
141
+ weight_name = kwargs.pop("weight_name", None)
142
+ use_safetensors = kwargs.pop("use_safetensors", None)
143
+ adapter_name = kwargs.pop("adapter_name", None)
144
+ _pipeline = kwargs.pop("_pipeline", None)
145
+ network_alphas = kwargs.pop("network_alphas", None)
146
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
147
+ allow_pickle = False
148
+
149
+ if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
150
+ raise ValueError(
151
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
152
+ )
153
+
154
+ if use_safetensors is None:
155
+ use_safetensors = True
156
+ allow_pickle = True
157
+
158
+ user_agent = {
159
+ "file_type": "attn_procs_weights",
160
+ "framework": "pytorch",
161
+ }
162
+
163
+ model_file = None
164
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
165
+ # Let's first try to load .safetensors weights
166
+ if (use_safetensors and weight_name is None) or (
167
+ weight_name is not None and weight_name.endswith(".safetensors")
168
+ ):
169
+ try:
170
+ model_file = _get_model_file(
171
+ pretrained_model_name_or_path_or_dict,
172
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
173
+ cache_dir=cache_dir,
174
+ force_download=force_download,
175
+ proxies=proxies,
176
+ local_files_only=local_files_only,
177
+ token=token,
178
+ revision=revision,
179
+ subfolder=subfolder,
180
+ user_agent=user_agent,
181
+ )
182
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
183
+ except IOError as e:
184
+ if not allow_pickle:
185
+ raise e
186
+ # try loading non-safetensors weights
187
+ pass
188
+ if model_file is None:
189
+ model_file = _get_model_file(
190
+ pretrained_model_name_or_path_or_dict,
191
+ weights_name=weight_name or LORA_WEIGHT_NAME,
192
+ cache_dir=cache_dir,
193
+ force_download=force_download,
194
+ proxies=proxies,
195
+ local_files_only=local_files_only,
196
+ token=token,
197
+ revision=revision,
198
+ subfolder=subfolder,
199
+ user_agent=user_agent,
200
+ )
201
+ state_dict = load_state_dict(model_file)
202
+ else:
203
+ state_dict = pretrained_model_name_or_path_or_dict
204
+
205
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
206
+ is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
207
+ is_model_cpu_offload = False
208
+ is_sequential_cpu_offload = False
209
+
210
+ if is_lora:
211
+ deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
212
+ deprecate("load_attn_procs", "0.40.0", deprecation_message)
213
+
214
+ if is_custom_diffusion:
215
+ attn_processors = self._process_custom_diffusion(state_dict=state_dict)
216
+ elif is_lora:
217
+ is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
218
+ state_dict=state_dict,
219
+ unet_identifier_key=self.unet_name,
220
+ network_alphas=network_alphas,
221
+ adapter_name=adapter_name,
222
+ _pipeline=_pipeline,
223
+ low_cpu_mem_usage=low_cpu_mem_usage,
224
+ )
225
+ else:
226
+ raise ValueError(
227
+ f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
228
+ )
229
+
230
+ # <Unsafe code
231
+ # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
232
+ # Now we remove any existing hooks to `_pipeline`.
233
+
234
+ # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
235
+ if is_custom_diffusion and _pipeline is not None:
236
+ is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
237
+
238
+ # only custom diffusion needs to set attn processors
239
+ self.set_attn_processor(attn_processors)
240
+ self.to(dtype=self.dtype, device=self.device)
241
+
242
+ # Offload back.
243
+ if is_model_cpu_offload:
244
+ _pipeline.enable_model_cpu_offload()
245
+ elif is_sequential_cpu_offload:
246
+ _pipeline.enable_sequential_cpu_offload()
247
+ # Unsafe code />
248
+
249
+ def _process_custom_diffusion(self, state_dict):
250
+ from ..models.attention_processor import CustomDiffusionAttnProcessor
251
+
252
+ attn_processors = {}
253
+ custom_diffusion_grouped_dict = defaultdict(dict)
254
+ for key, value in state_dict.items():
255
+ if len(value) == 0:
256
+ custom_diffusion_grouped_dict[key] = {}
257
+ else:
258
+ if "to_out" in key:
259
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
260
+ else:
261
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
262
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
263
+
264
+ for key, value_dict in custom_diffusion_grouped_dict.items():
265
+ if len(value_dict) == 0:
266
+ attn_processors[key] = CustomDiffusionAttnProcessor(
267
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
268
+ )
269
+ else:
270
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
271
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
272
+ train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
273
+ attn_processors[key] = CustomDiffusionAttnProcessor(
274
+ train_kv=True,
275
+ train_q_out=train_q_out,
276
+ hidden_size=hidden_size,
277
+ cross_attention_dim=cross_attention_dim,
278
+ )
279
+ attn_processors[key].load_state_dict(value_dict)
280
+
281
+ return attn_processors
282
+
283
+ def _process_lora(
284
+ self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
285
+ ):
286
+ # This method does the following things:
287
+ # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
288
+ # format. For legacy format no filtering is applied.
289
+ # 2. Converts the `state_dict` to the `peft` compatible format.
290
+ # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
291
+ # `LoraConfig` specs.
292
+ # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
293
+ if not USE_PEFT_BACKEND:
294
+ raise ValueError("PEFT backend is required for this method.")
295
+
296
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
297
+
298
+ keys = list(state_dict.keys())
299
+
300
+ unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
301
+ unet_state_dict = {
302
+ k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
303
+ }
304
+
305
+ if network_alphas is not None:
306
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
307
+ network_alphas = {
308
+ k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
309
+ }
310
+
311
+ is_model_cpu_offload = False
312
+ is_sequential_cpu_offload = False
313
+ state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
314
+
315
+ if len(state_dict_to_be_used) > 0:
316
+ if adapter_name in getattr(self, "peft_config", {}):
317
+ raise ValueError(
318
+ f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
319
+ )
320
+
321
+ state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
322
+
323
+ if network_alphas is not None:
324
+ # The alphas state dict has the same structure as the UNet, thus we convert it to peft format using
325
+ # `convert_unet_state_dict_to_peft` method.
326
+ network_alphas = convert_unet_state_dict_to_peft(network_alphas)
327
+
328
+ rank = {}
329
+ for key, val in state_dict.items():
330
+ if "lora_B" in key:
331
+ rank[key] = val.shape[1]
332
+
333
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
334
+ if "use_dora" in lora_config_kwargs:
335
+ if lora_config_kwargs["use_dora"]:
336
+ if is_peft_version("<", "0.9.0"):
337
+ raise ValueError(
338
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
339
+ )
340
+ else:
341
+ if is_peft_version("<", "0.9.0"):
342
+ lora_config_kwargs.pop("use_dora")
343
+ lora_config = LoraConfig(**lora_config_kwargs)
344
+
345
+ # adapter_name
346
+ if adapter_name is None:
347
+ adapter_name = get_adapter_name(self)
348
+
349
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
350
+ # otherwise loading LoRA weights will lead to an error
351
+ is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
352
+ peft_kwargs = {}
353
+ if is_peft_version(">=", "0.13.1"):
354
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
355
+
356
+ inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
357
+ incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
358
+
359
+ warn_msg = ""
360
+ if incompatible_keys is not None:
361
+ # Check only for unexpected keys.
362
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
363
+ if unexpected_keys:
364
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
365
+ if lora_unexpected_keys:
366
+ warn_msg = (
367
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
368
+ f" {', '.join(lora_unexpected_keys)}. "
369
+ )
370
+
371
+ # Filter missing keys specific to the current adapter.
372
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
373
+ if missing_keys:
374
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
375
+ if lora_missing_keys:
376
+ warn_msg += (
377
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
378
+ f" {', '.join(lora_missing_keys)}."
379
+ )
380
+
381
+ if warn_msg:
382
+ logger.warning(warn_msg)
383
+
384
+ return is_model_cpu_offload, is_sequential_cpu_offload
385
+
386
+ @classmethod
387
+ # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
388
+ def _optionally_disable_offloading(cls, _pipeline):
389
+ """
390
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
391
+
392
+ Args:
393
+ _pipeline (`DiffusionPipeline`):
394
+ The pipeline to disable offloading for.
395
+
396
+ Returns:
397
+ tuple:
398
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
399
+ """
400
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
401
+
402
+ def save_attn_procs(
403
+ self,
404
+ save_directory: Union[str, os.PathLike],
405
+ is_main_process: bool = True,
406
+ weight_name: str = None,
407
+ save_function: Callable = None,
408
+ safe_serialization: bool = True,
409
+ **kwargs,
410
+ ):
411
+ r"""
412
+ Save attention processor layers to a directory so that it can be reloaded with the
413
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
414
+
415
+ Arguments:
416
+ save_directory (`str` or `os.PathLike`):
417
+ Directory to save an attention processor to (will be created if it doesn't exist).
418
+ is_main_process (`bool`, *optional*, defaults to `True`):
419
+ Whether the process calling this is the main process or not. Useful during distributed training when you
420
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
421
+ process to avoid race conditions.
422
+ save_function (`Callable`):
423
+ The function to use to save the state dictionary. Useful during distributed training when you need to
424
+ replace `torch.save` with another method. Can be configured with the environment variable
425
+ `DIFFUSERS_SAVE_MODE`.
426
+ safe_serialization (`bool`, *optional*, defaults to `True`):
427
+ Whether to save the model using `safetensors` or with `pickle`.
428
+
429
+ Example:
430
+
431
+ ```py
432
+ import torch
433
+ from diffusers import DiffusionPipeline
434
+
435
+ pipeline = DiffusionPipeline.from_pretrained(
436
+ "CompVis/stable-diffusion-v1-4",
437
+ torch_dtype=torch.float16,
438
+ ).to("cuda")
439
+ pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
440
+ pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
441
+ ```
442
+ """
443
+ from ..models.attention_processor import (
444
+ CustomDiffusionAttnProcessor,
445
+ CustomDiffusionAttnProcessor2_0,
446
+ CustomDiffusionXFormersAttnProcessor,
447
+ )
448
+
449
+ if os.path.isfile(save_directory):
450
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
451
+ return
452
+
453
+ is_custom_diffusion = any(
454
+ isinstance(
455
+ x,
456
+ (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
457
+ )
458
+ for (_, x) in self.attn_processors.items()
459
+ )
460
+ if is_custom_diffusion:
461
+ state_dict = self._get_custom_diffusion_state_dict()
462
+ if save_function is None and safe_serialization:
463
+ # safetensors does not support saving dicts with non-tensor values
464
+ empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
465
+ if len(empty_state_dict) > 0:
466
+ logger.warning(
467
+ f"Safetensors does not support saving dicts with non-tensor values. "
468
+ f"The following keys will be ignored: {empty_state_dict.keys()}"
469
+ )
470
+ state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
471
+ else:
472
+ deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
473
+ deprecate("save_attn_procs", "0.40.0", deprecation_message)
474
+
475
+ if not USE_PEFT_BACKEND:
476
+ raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
477
+
478
+ from peft.utils import get_peft_model_state_dict
479
+
480
+ state_dict = get_peft_model_state_dict(self)
481
+
482
+ if save_function is None:
483
+ if safe_serialization:
484
+
485
+ def save_function(weights, filename):
486
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
487
+
488
+ else:
489
+ save_function = torch.save
490
+
491
+ os.makedirs(save_directory, exist_ok=True)
492
+
493
+ if weight_name is None:
494
+ if safe_serialization:
495
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
496
+ else:
497
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
498
+
499
+ # Save the model
500
+ save_path = Path(save_directory, weight_name).as_posix()
501
+ save_function(state_dict, save_path)
502
+ logger.info(f"Model weights saved in {save_path}")
503
+
504
+ def _get_custom_diffusion_state_dict(self):
505
+ from ..models.attention_processor import (
506
+ CustomDiffusionAttnProcessor,
507
+ CustomDiffusionAttnProcessor2_0,
508
+ CustomDiffusionXFormersAttnProcessor,
509
+ )
510
+
511
+ model_to_save = AttnProcsLayers(
512
+ {
513
+ y: x
514
+ for (y, x) in self.attn_processors.items()
515
+ if isinstance(
516
+ x,
517
+ (
518
+ CustomDiffusionAttnProcessor,
519
+ CustomDiffusionAttnProcessor2_0,
520
+ CustomDiffusionXFormersAttnProcessor,
521
+ ),
522
+ )
523
+ }
524
+ )
525
+ state_dict = model_to_save.state_dict()
526
+ for name, attn in self.attn_processors.items():
527
+ if len(attn.state_dict()) == 0:
528
+ state_dict[name] = {}
529
+
530
+ return state_dict
531
+
532
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
533
+ if low_cpu_mem_usage:
534
+ if is_accelerate_available():
535
+ from accelerate import init_empty_weights
536
+
537
+ else:
538
+ low_cpu_mem_usage = False
539
+ logger.warning(
540
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
541
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
542
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
543
+ " install accelerate\n```\n."
544
+ )
545
+
546
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
547
+ raise NotImplementedError(
548
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
549
+ " `low_cpu_mem_usage=False`."
550
+ )
551
+
552
+ updated_state_dict = {}
553
+ image_projection = None
554
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
555
+
556
+ if "proj.weight" in state_dict:
557
+ # IP-Adapter
558
+ num_image_text_embeds = 4
559
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
560
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
561
+
562
+ with init_context():
563
+ image_projection = ImageProjection(
564
+ cross_attention_dim=cross_attention_dim,
565
+ image_embed_dim=clip_embeddings_dim,
566
+ num_image_text_embeds=num_image_text_embeds,
567
+ )
568
+
569
+ for key, value in state_dict.items():
570
+ diffusers_name = key.replace("proj", "image_embeds")
571
+ updated_state_dict[diffusers_name] = value
572
+
573
+ elif "proj.3.weight" in state_dict:
574
+ # IP-Adapter Full
575
+ clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
576
+ cross_attention_dim = state_dict["proj.3.weight"].shape[0]
577
+
578
+ with init_context():
579
+ image_projection = IPAdapterFullImageProjection(
580
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
581
+ )
582
+
583
+ for key, value in state_dict.items():
584
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
585
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
586
+ diffusers_name = diffusers_name.replace("proj.3", "norm")
587
+ updated_state_dict[diffusers_name] = value
588
+
589
+ elif "perceiver_resampler.proj_in.weight" in state_dict:
590
+ # IP-Adapter Face ID Plus
591
+ id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
592
+ embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
593
+ hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
594
+ output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
595
+ heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
596
+
597
+ with init_context():
598
+ image_projection = IPAdapterFaceIDPlusImageProjection(
599
+ embed_dims=embed_dims,
600
+ output_dims=output_dims,
601
+ hidden_dims=hidden_dims,
602
+ heads=heads,
603
+ id_embeddings_dim=id_embeddings_dim,
604
+ )
605
+
606
+ for key, value in state_dict.items():
607
+ diffusers_name = key.replace("perceiver_resampler.", "")
608
+ diffusers_name = diffusers_name.replace("0.to", "attn.to")
609
+ diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
610
+ diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
611
+ diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
612
+ diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
613
+ diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
614
+ diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
615
+ diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
616
+ diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
617
+ diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
618
+ diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
619
+ diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
620
+ diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
621
+ diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
622
+ diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
623
+ diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
624
+ diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
625
+ diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
626
+ diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
627
+ diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
628
+ diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
629
+
630
+ if "norm1" in diffusers_name:
631
+ updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
632
+ elif "norm2" in diffusers_name:
633
+ updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
634
+ elif "to_kv" in diffusers_name:
635
+ v_chunk = value.chunk(2, dim=0)
636
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
637
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
638
+ elif "to_out" in diffusers_name:
639
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
640
+ elif "proj.0.weight" == diffusers_name:
641
+ updated_state_dict["proj.net.0.proj.weight"] = value
642
+ elif "proj.0.bias" == diffusers_name:
643
+ updated_state_dict["proj.net.0.proj.bias"] = value
644
+ elif "proj.2.weight" == diffusers_name:
645
+ updated_state_dict["proj.net.2.weight"] = value
646
+ elif "proj.2.bias" == diffusers_name:
647
+ updated_state_dict["proj.net.2.bias"] = value
648
+ else:
649
+ updated_state_dict[diffusers_name] = value
650
+
651
+ elif "norm.weight" in state_dict:
652
+ # IP-Adapter Face ID
653
+ id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
654
+ id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
655
+ multiplier = id_embeddings_dim_out // id_embeddings_dim_in
656
+ norm_layer = "norm.weight"
657
+ cross_attention_dim = state_dict[norm_layer].shape[0]
658
+ num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
659
+
660
+ with init_context():
661
+ image_projection = IPAdapterFaceIDImageProjection(
662
+ cross_attention_dim=cross_attention_dim,
663
+ image_embed_dim=id_embeddings_dim_in,
664
+ mult=multiplier,
665
+ num_tokens=num_tokens,
666
+ )
667
+
668
+ for key, value in state_dict.items():
669
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
670
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
671
+ updated_state_dict[diffusers_name] = value
672
+
673
+ else:
674
+ # IP-Adapter Plus
675
+ num_image_text_embeds = state_dict["latents"].shape[1]
676
+ embed_dims = state_dict["proj_in.weight"].shape[1]
677
+ output_dims = state_dict["proj_out.weight"].shape[0]
678
+ hidden_dims = state_dict["latents"].shape[2]
679
+ attn_key_present = any("attn" in k for k in state_dict)
680
+ heads = (
681
+ state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
682
+ if attn_key_present
683
+ else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
684
+ )
685
+
686
+ with init_context():
687
+ image_projection = IPAdapterPlusImageProjection(
688
+ embed_dims=embed_dims,
689
+ output_dims=output_dims,
690
+ hidden_dims=hidden_dims,
691
+ heads=heads,
692
+ num_queries=num_image_text_embeds,
693
+ )
694
+
695
+ for key, value in state_dict.items():
696
+ diffusers_name = key.replace("0.to", "2.to")
697
+
698
+ diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
699
+ diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
700
+ diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
701
+ diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
702
+ diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
703
+ diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
704
+ diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
705
+ diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
706
+
707
+ if "to_kv" in diffusers_name:
708
+ parts = diffusers_name.split(".")
709
+ parts[2] = "attn"
710
+ diffusers_name = ".".join(parts)
711
+ v_chunk = value.chunk(2, dim=0)
712
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
713
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
714
+ elif "to_q" in diffusers_name:
715
+ parts = diffusers_name.split(".")
716
+ parts[2] = "attn"
717
+ diffusers_name = ".".join(parts)
718
+ updated_state_dict[diffusers_name] = value
719
+ elif "to_out" in diffusers_name:
720
+ parts = diffusers_name.split(".")
721
+ parts[2] = "attn"
722
+ diffusers_name = ".".join(parts)
723
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
724
+ else:
725
+ diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
726
+ diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
727
+ diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
728
+
729
+ diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
730
+ diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
731
+ diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
732
+
733
+ diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
734
+ diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
735
+ diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
736
+
737
+ diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
738
+ diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
739
+ diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
740
+ updated_state_dict[diffusers_name] = value
741
+
742
+ if not low_cpu_mem_usage:
743
+ image_projection.load_state_dict(updated_state_dict, strict=True)
744
+ else:
745
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
746
+
747
+ return image_projection
748
+
749
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
750
+ from ..models.attention_processor import (
751
+ IPAdapterAttnProcessor,
752
+ IPAdapterAttnProcessor2_0,
753
+ IPAdapterXFormersAttnProcessor,
754
+ )
755
+
756
+ if low_cpu_mem_usage:
757
+ if is_accelerate_available():
758
+ from accelerate import init_empty_weights
759
+
760
+ else:
761
+ low_cpu_mem_usage = False
762
+ logger.warning(
763
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
764
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
765
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
766
+ " install accelerate\n```\n."
767
+ )
768
+
769
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
770
+ raise NotImplementedError(
771
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
772
+ " `low_cpu_mem_usage=False`."
773
+ )
774
+
775
+ # set ip-adapter cross-attention processors & load state_dict
776
+ attn_procs = {}
777
+ key_id = 1
778
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
779
+ for name in self.attn_processors.keys():
780
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
781
+ if name.startswith("mid_block"):
782
+ hidden_size = self.config.block_out_channels[-1]
783
+ elif name.startswith("up_blocks"):
784
+ block_id = int(name[len("up_blocks.")])
785
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
786
+ elif name.startswith("down_blocks"):
787
+ block_id = int(name[len("down_blocks.")])
788
+ hidden_size = self.config.block_out_channels[block_id]
789
+
790
+ if cross_attention_dim is None or "motion_modules" in name:
791
+ attn_processor_class = self.attn_processors[name].__class__
792
+ attn_procs[name] = attn_processor_class()
793
+ else:
794
+ if "XFormers" in str(self.attn_processors[name].__class__):
795
+ attn_processor_class = IPAdapterXFormersAttnProcessor
796
+ else:
797
+ attn_processor_class = (
798
+ IPAdapterAttnProcessor2_0
799
+ if hasattr(F, "scaled_dot_product_attention")
800
+ else IPAdapterAttnProcessor
801
+ )
802
+ num_image_text_embeds = []
803
+ for state_dict in state_dicts:
804
+ if "proj.weight" in state_dict["image_proj"]:
805
+ # IP-Adapter
806
+ num_image_text_embeds += [4]
807
+ elif "proj.3.weight" in state_dict["image_proj"]:
808
+ # IP-Adapter Full Face
809
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
810
+ elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
811
+ # IP-Adapter Face ID Plus
812
+ num_image_text_embeds += [4]
813
+ elif "norm.weight" in state_dict["image_proj"]:
814
+ # IP-Adapter Face ID
815
+ num_image_text_embeds += [4]
816
+ else:
817
+ # IP-Adapter Plus
818
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
819
+
820
+ with init_context():
821
+ attn_procs[name] = attn_processor_class(
822
+ hidden_size=hidden_size,
823
+ cross_attention_dim=cross_attention_dim,
824
+ scale=1.0,
825
+ num_tokens=num_image_text_embeds,
826
+ )
827
+
828
+ value_dict = {}
829
+ for i, state_dict in enumerate(state_dicts):
830
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
831
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
832
+
833
+ if not low_cpu_mem_usage:
834
+ attn_procs[name].load_state_dict(value_dict)
835
+ else:
836
+ device = next(iter(value_dict.values())).device
837
+ dtype = next(iter(value_dict.values())).dtype
838
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
839
+
840
+ key_id += 2
841
+
842
+ return attn_procs
843
+
844
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
845
+ if not isinstance(state_dicts, list):
846
+ state_dicts = [state_dicts]
847
+
848
+ # The Kolors UNet already has an `encoder_hid_proj`
849
+ if (
850
+ self.encoder_hid_proj is not None
851
+ and self.config.encoder_hid_dim_type == "text_proj"
852
+ and not hasattr(self, "text_encoder_hid_proj")
853
+ ):
854
+ self.text_encoder_hid_proj = self.encoder_hid_proj
855
+
856
+ # Set encoder_hid_proj after loading ip_adapter weights,
857
+ # because `IPAdapterPlusImageProjection` also has `attn_processors`.
858
+ self.encoder_hid_proj = None
859
+
860
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
861
+ self.set_attn_processor(attn_procs)
862
+
863
+ # convert IP-Adapter Image Projection layers to diffusers
864
+ image_projection_layers = []
865
+ for state_dict in state_dicts:
866
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
867
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
868
+ )
869
+ image_projection_layers.append(image_projection_layer)
870
+
871
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
872
+ self.config.encoder_hid_dim_type = "ip_image_proj"
873
+
874
+ self.to(dtype=self.dtype, device=self.device)
875
+
876
+ def _load_ip_adapter_loras(self, state_dicts):
877
+ lora_dicts = {}
878
+ for key_id, name in enumerate(self.attn_processors.keys()):
879
+ for i, state_dict in enumerate(state_dicts):
880
+ if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
881
+ if i not in lora_dicts:
882
+ lora_dicts[i] = {}
883
+ lora_dicts[i].update(
884
+ {
885
+ f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
886
+ f"{key_id}.to_k_lora.down.weight"
887
+ ]
888
+ }
889
+ )
890
+ lora_dicts[i].update(
891
+ {
892
+ f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
893
+ f"{key_id}.to_q_lora.down.weight"
894
+ ]
895
+ }
896
+ )
897
+ lora_dicts[i].update(
898
+ {
899
+ f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
900
+ f"{key_id}.to_v_lora.down.weight"
901
+ ]
902
+ }
903
+ )
904
+ lora_dicts[i].update(
905
+ {
906
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
907
+ f"{key_id}.to_out_lora.down.weight"
908
+ ]
909
+ }
910
+ )
911
+ lora_dicts[i].update(
912
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
913
+ )
914
+ lora_dicts[i].update(
915
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
916
+ )
917
+ lora_dicts[i].update(
918
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
919
+ )
920
+ lora_dicts[i].update(
921
+ {
922
+ f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
923
+ f"{key_id}.to_out_lora.up.weight"
924
+ ]
925
+ }
926
+ )
927
+ return lora_dicts
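One detail worth calling out from `_process_custom_diffusion` above: keys containing `to_out` carry an extra module index (`to_out_custom_diffusion.0`), so three trailing components are split off instead of two. A standalone sketch of that grouping; the key names are representative examples and the tensor shapes are dummies:

```py
from collections import defaultdict

import torch

state_dict = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": torch.zeros(320, 768),
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_out_custom_diffusion.0.weight": torch.zeros(320, 320),
}

grouped = defaultdict(dict)
for key, value in state_dict.items():
    if "to_out" in key:
        # keep "to_out_custom_diffusion.0.weight" together as the sub key
        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    else:
        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
    grouped[attn_processor_key][sub_key] = value

processor_key = "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"
print(sorted(grouped[processor_key]))
# ['to_k_custom_diffusion.weight', 'to_out_custom_diffusion.0.weight']
```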
icedit/diffusers/loaders/unet_loader_utils.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+ from typing import TYPE_CHECKING, Dict, List, Union
16
+
17
+ from ..utils import logging
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ # import here to avoid circular imports
22
+ from ..models import UNet2DConditionModel
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ def _translate_into_actual_layer_name(name):
28
+ """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
29
+ if name == "mid":
30
+ return "mid_block.attentions.0"
31
+
32
+ updown, block, attn = name.split(".")
33
+
34
+ updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
35
+ block = block.replace("block_", "")
36
+ attn = "attentions." + attn
37
+
38
+ return ".".join((updown, block, attn))
39
+
40
+
41
+ def _maybe_expand_lora_scales(
42
+ unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
43
+ ):
44
+ blocks_with_transformer = {
45
+ "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
46
+ "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
47
+ }
48
+ transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
49
+
50
+ expanded_weight_scales = [
51
+ _maybe_expand_lora_scales_for_one_adapter(
52
+ weight_for_adapter,
53
+ blocks_with_transformer,
54
+ transformer_per_block,
55
+ unet.state_dict(),
56
+ default_scale=default_scale,
57
+ )
58
+ for weight_for_adapter in weight_scales
59
+ ]
60
+
61
+ return expanded_weight_scales
62
+
63
+
64
+ def _maybe_expand_lora_scales_for_one_adapter(
65
+ scales: Union[float, Dict],
66
+ blocks_with_transformer: Dict[str, int],
67
+ transformer_per_block: Dict[str, int],
68
+ state_dict: None,
69
+ default_scale: float = 1.0,
70
+ ):
71
+ """
72
+ Expands the inputs into a more granular dictionary. See the example below for more details.
73
+
74
+ Parameters:
75
+ scales (`Union[float, Dict]`):
76
+ Scales dict to expand.
77
+ blocks_with_transformer (`Dict[str, int]`):
78
+ Dict with keys 'up' and 'down', showing which blocks have transformer layers
79
+ transformer_per_block (`Dict[str, int]`):
80
+ Dict with keys 'up' and 'down', showing how many transformer layers each block has
81
+
82
+ E.g. turns
83
+ ```python
84
+ scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
85
+ blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
86
+ transformer_per_block = {"down": 2, "up": 3}
87
+ ```
88
+ into
89
+ ```python
90
+ {
91
+ "down.block_1.0": 2,
92
+ "down.block_1.1": 2,
93
+ "down.block_2.0": 2,
94
+ "down.block_2.1": 2,
95
+ "mid": 3,
96
+ "up.block_0.0": 4,
97
+ "up.block_0.1": 4,
98
+ "up.block_0.2": 4,
99
+ "up.block_1.0": 5,
100
+ "up.block_1.1": 6,
101
+ "up.block_1.2": 7,
102
+ }
103
+ ```
104
+ """
105
+ if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
106
+ raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
107
+
108
+ if sorted(transformer_per_block.keys()) != ["down", "up"]:
109
+ raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
110
+
111
+ if not isinstance(scales, dict):
112
+ # don't expand if scales is a single number
113
+ return scales
114
+
115
+ scales = copy.deepcopy(scales)
116
+
117
+ if "mid" not in scales:
118
+ scales["mid"] = default_scale
119
+ elif isinstance(scales["mid"], list):
120
+ if len(scales["mid"]) == 1:
121
+ scales["mid"] = scales["mid"][0]
122
+ else:
123
+ raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
124
+
125
+ for updown in ["up", "down"]:
126
+ if updown not in scales:
127
+ scales[updown] = default_scale
128
+
129
+ # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
130
+ if not isinstance(scales[updown], dict):
131
+ scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
132
+
133
+ # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
134
+ for i in blocks_with_transformer[updown]:
135
+ block = f"block_{i}"
136
+ # set not assigned blocks to default scale
137
+ if block not in scales[updown]:
138
+ scales[updown][block] = default_scale
139
+ if not isinstance(scales[updown][block], list):
140
+ scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
141
+ elif len(scales[updown][block]) == 1:
142
+ # a list specifying scale to each masked IP input
143
+ scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
144
+ elif len(scales[updown][block]) != transformer_per_block[updown]:
145
+ raise ValueError(
146
+ f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
147
+ )
148
+
149
+ # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
150
+ for i in blocks_with_transformer[updown]:
151
+ block = f"block_{i}"
152
+ for tf_idx, value in enumerate(scales[updown][block]):
153
+ scales[f"{updown}.{block}.{tf_idx}"] = value
154
+
155
+ del scales[updown]
156
+
157
+ for layer in scales.keys():
158
+ if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
159
+ raise ValueError(
160
+ f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
161
+ )
162
+
163
+ return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
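To make the docstring example above concrete, here is a simplified re-implementation of the scale expansion. It is only a sketch: it skips the input validation and the final check against the UNet `state_dict` that the real helper performs, and the `expand_scales` name is invented for illustration:

```py
def expand_scales(scales, blocks_with_transformer, transformer_per_block, default_scale=1.0):
    # Simplified: no validation, no check that the expanded layers exist in the UNet.
    out = {"mid": scales.get("mid", default_scale)}
    for updown in ("down", "up"):
        per_side = scales.get(updown, default_scale)
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            value = per_side.get(block, default_scale) if isinstance(per_side, dict) else per_side
            values = value if isinstance(value, list) else [value] * transformer_per_block[updown]
            for tf_idx, v in enumerate(values):
                out[f"{updown}.{block}.{tf_idx}"] = v
    return out


expanded = expand_scales(
    scales={"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}},
    blocks_with_transformer={"down": [1, 2], "up": [0, 1]},
    transformer_per_block={"down": 2, "up": 3},
)
assert expanded["mid"] == 3
assert expanded["down.block_2.1"] == 2
assert expanded["up.block_1.2"] == 7
```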
icedit/diffusers/loaders/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict
16
+
17
+ import torch
18
+
19
+
20
+ class AttnProcsLayers(torch.nn.Module):
21
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
22
+ super().__init__()
23
+ self.layers = torch.nn.ModuleList(state_dict.values())
24
+ self.mapping = dict(enumerate(state_dict.keys()))
25
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
26
+
27
+ # .processor for unet, .self_attn for text encoder
28
+ self.split_keys = [".processor", ".self_attn"]
29
+
30
+ # we add a hook to state_dict() and load_state_dict() so that the
31
+ # naming fits with `unet.attn_processors`
32
+ def map_to(module, state_dict, *args, **kwargs):
33
+ new_state_dict = {}
34
+ for key, value in state_dict.items():
35
+ num = int(key.split(".")[1]) # 0 is always "layers"
36
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
37
+ new_state_dict[new_key] = value
38
+
39
+ return new_state_dict
40
+
41
+ def remap_key(key, state_dict):
42
+ for k in self.split_keys:
43
+ if k in key:
44
+ return key.split(k)[0] + k
45
+
46
+ raise ValueError(
47
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
48
+ )
49
+
50
+ def map_from(module, state_dict, *args, **kwargs):
51
+ all_keys = list(state_dict.keys())
52
+ for key in all_keys:
53
+ replace_key = remap_key(key, state_dict)
54
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
55
+ state_dict[new_key] = state_dict[key]
56
+ del state_dict[key]
57
+
58
+ self._register_state_dict_hook(map_to)
59
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
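A minimal sketch of how `AttnProcsLayers` re-keys its `state_dict()` to match `unet.attn_processors` naming, assuming the vendored tree resolves as `diffusers` on `sys.path` (as `app.py` arranges); the toy processor (a plain `nn.Linear`) and the layer key are illustrative only.

import torch
from diffusers.loaders.utils import AttnProcsLayers  # import path assumed for the vendored copy

# A toy "processor" with parameters, keyed the way UNet attention processors are keyed.
procs = {"down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": torch.nn.Linear(4, 4)}
wrapped = AttnProcsLayers(procs)

# The state_dict hook maps "layers.0.weight" / "layers.0.bias" back onto the processor key:
print(sorted(wrapped.state_dict().keys()))
# ['down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.bias',
#  'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.weight']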
icedit/diffusers/models/__init__.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import (
18
+ DIFFUSERS_SLOW_IMPORT,
19
+ _LazyModule,
20
+ is_flax_available,
21
+ is_torch_available,
22
+ )
23
+
24
+
25
+ _import_structure = {}
26
+
27
+ if is_torch_available():
28
+ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
29
+ _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
30
+ _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
31
+ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
32
+ _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
33
+ _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
34
+ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
35
+ _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
36
+ _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
37
+ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
38
+ _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
39
+ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
40
+ _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
41
+ _import_structure["autoencoders.vq_model"] = ["VQModel"]
42
+ _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
43
+ _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
44
+ _import_structure["controlnets.controlnet_hunyuan"] = [
45
+ "HunyuanDiT2DControlNetModel",
46
+ "HunyuanDiT2DMultiControlNetModel",
47
+ ]
48
+ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
49
+ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
50
+ _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
51
+ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
52
+ _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
53
+ _import_structure["embeddings"] = ["ImageProjection"]
54
+ _import_structure["modeling_utils"] = ["ModelMixin"]
55
+ _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
56
+ _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
57
+ _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
58
+ _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
59
+ _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
60
+ _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
61
+ _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
62
+ _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
63
+ _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
64
+ _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
65
+ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
66
+ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
67
+ _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
68
+ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
69
+ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
70
+ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
71
+ _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
72
+ _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
73
+ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
74
+ _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
75
+ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
76
+ _import_structure["unets.unet_1d"] = ["UNet1DModel"]
77
+ _import_structure["unets.unet_2d"] = ["UNet2DModel"]
78
+ _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
79
+ _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
80
+ _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
81
+ _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
82
+ _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
83
+ _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
84
+ _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
85
+ _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
86
+
87
+ if is_flax_available():
88
+ _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
89
+ _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
90
+ _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
91
+
92
+
93
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
94
+ if is_torch_available():
95
+ from .adapter import MultiAdapter, T2IAdapter
96
+ from .autoencoders import (
97
+ AsymmetricAutoencoderKL,
98
+ AutoencoderDC,
99
+ AutoencoderKL,
100
+ AutoencoderKLAllegro,
101
+ AutoencoderKLCogVideoX,
102
+ AutoencoderKLHunyuanVideo,
103
+ AutoencoderKLLTXVideo,
104
+ AutoencoderKLMochi,
105
+ AutoencoderKLTemporalDecoder,
106
+ AutoencoderOobleck,
107
+ AutoencoderTiny,
108
+ ConsistencyDecoderVAE,
109
+ VQModel,
110
+ )
111
+ from .controlnets import (
112
+ ControlNetModel,
113
+ ControlNetUnionModel,
114
+ ControlNetXSAdapter,
115
+ FluxControlNetModel,
116
+ FluxMultiControlNetModel,
117
+ HunyuanDiT2DControlNetModel,
118
+ HunyuanDiT2DMultiControlNetModel,
119
+ MultiControlNetModel,
120
+ SD3ControlNetModel,
121
+ SD3MultiControlNetModel,
122
+ SparseControlNetModel,
123
+ UNetControlNetXSModel,
124
+ )
125
+ from .embeddings import ImageProjection
126
+ from .modeling_utils import ModelMixin
127
+ from .transformers import (
128
+ AllegroTransformer3DModel,
129
+ AuraFlowTransformer2DModel,
130
+ CogVideoXTransformer3DModel,
131
+ CogView3PlusTransformer2DModel,
132
+ DiTTransformer2DModel,
133
+ DualTransformer2DModel,
134
+ FluxTransformer2DModel,
135
+ HunyuanDiT2DModel,
136
+ HunyuanVideoTransformer3DModel,
137
+ LatteTransformer3DModel,
138
+ LTXVideoTransformer3DModel,
139
+ LuminaNextDiT2DModel,
140
+ MochiTransformer3DModel,
141
+ PixArtTransformer2DModel,
142
+ PriorTransformer,
143
+ SanaTransformer2DModel,
144
+ SD3Transformer2DModel,
145
+ StableAudioDiTModel,
146
+ T5FilmDecoder,
147
+ Transformer2DModel,
148
+ TransformerTemporalModel,
149
+ )
150
+ from .unets import (
151
+ I2VGenXLUNet,
152
+ Kandinsky3UNet,
153
+ MotionAdapter,
154
+ StableCascadeUNet,
155
+ UNet1DModel,
156
+ UNet2DConditionModel,
157
+ UNet2DModel,
158
+ UNet3DConditionModel,
159
+ UNetMotionModel,
160
+ UNetSpatioTemporalConditionModel,
161
+ UVit2DModel,
162
+ )
163
+
164
+ if is_flax_available():
165
+ from .controlnets import FlaxControlNetModel
166
+ from .unets import FlaxUNet2DConditionModel
167
+ from .vae_flax import FlaxAutoencoderKL
168
+
169
+ else:
170
+ import sys
171
+
172
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
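Because of the `_LazyModule` indirection above, submodules are only imported when an attribute is first accessed. A small sketch, assuming the vendored tree is importable as `diffusers`:

# Each name below is resolved lazily: only the submodules backing the requested
# attributes are imported, not the whole `diffusers.models` package tree.
from diffusers.models import AutoencoderKL, FluxTransformer2DModel

print(AutoencoderKL.__module__)           # diffusers.models.autoencoders.autoencoder_kl
print(FluxTransformer2DModel.__module__)  # diffusers.models.transformers.transformer_flux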
icedit/diffusers/models/activations.py ADDED
@@ -0,0 +1,178 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate
21
+ from ..utils.import_utils import is_torch_npu_available, is_torch_version
22
+
23
+
24
+ if is_torch_npu_available():
25
+ import torch_npu
26
+
27
+ ACTIVATION_FUNCTIONS = {
28
+ "swish": nn.SiLU(),
29
+ "silu": nn.SiLU(),
30
+ "mish": nn.Mish(),
31
+ "gelu": nn.GELU(),
32
+ "relu": nn.ReLU(),
33
+ }
34
+
35
+
36
+ def get_activation(act_fn: str) -> nn.Module:
37
+ """Helper function to get activation function from string.
38
+
39
+ Args:
40
+ act_fn (str): Name of activation function.
41
+
42
+ Returns:
43
+ nn.Module: Activation function.
44
+ """
45
+
46
+ act_fn = act_fn.lower()
47
+ if act_fn in ACTIVATION_FUNCTIONS:
48
+ return ACTIVATION_FUNCTIONS[act_fn]
49
+ else:
50
+ raise ValueError(f"Unsupported activation function: {act_fn}")
51
+
52
+
53
+ class FP32SiLU(nn.Module):
54
+ r"""
55
+ SiLU activation function with input upcasted to torch.float32.
56
+ """
57
+
58
+ def __init__(self):
59
+ super().__init__()
60
+
61
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
62
+ return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
63
+
64
+
65
+ class GELU(nn.Module):
66
+ r"""
67
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
68
+
69
+ Parameters:
70
+ dim_in (`int`): The number of channels in the input.
71
+ dim_out (`int`): The number of channels in the output.
72
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
73
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
74
+ """
75
+
76
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
77
+ super().__init__()
78
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
79
+ self.approximate = approximate
80
+
81
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
82
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
83
+ # fp16 gelu not supported on mps before torch 2.0
84
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
85
+ return F.gelu(gate, approximate=self.approximate)
86
+
87
+ def forward(self, hidden_states):
88
+ hidden_states = self.proj(hidden_states)
89
+ hidden_states = self.gelu(hidden_states)
90
+ return hidden_states
91
+
92
+
93
+ class GEGLU(nn.Module):
94
+ r"""
95
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
96
+
97
+ Parameters:
98
+ dim_in (`int`): The number of channels in the input.
99
+ dim_out (`int`): The number of channels in the output.
100
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
101
+ """
102
+
103
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
104
+ super().__init__()
105
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
106
+
107
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
108
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
109
+ # fp16 gelu not supported on mps before torch 2.0
110
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
111
+ return F.gelu(gate)
112
+
113
+ def forward(self, hidden_states, *args, **kwargs):
114
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
115
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
116
+ deprecate("scale", "1.0.0", deprecation_message)
117
+ hidden_states = self.proj(hidden_states)
118
+ if is_torch_npu_available():
119
+ # using torch_npu.npu_geglu can run faster and save memory on NPU.
120
+ return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
121
+ else:
122
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
123
+ return hidden_states * self.gelu(gate)
124
+
125
+
126
+ class SwiGLU(nn.Module):
127
+ r"""
128
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
129
+ but uses SiLU / Swish instead of GeLU.
130
+
131
+ Parameters:
132
+ dim_in (`int`): The number of channels in the input.
133
+ dim_out (`int`): The number of channels in the output.
134
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
135
+ """
136
+
137
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
138
+ super().__init__()
139
+
140
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
141
+ self.activation = nn.SiLU()
142
+
143
+ def forward(self, hidden_states):
144
+ hidden_states = self.proj(hidden_states)
145
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
146
+ return hidden_states * self.activation(gate)
147
+
148
+
149
+ class ApproximateGELU(nn.Module):
150
+ r"""
151
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
152
+ [paper](https://arxiv.org/abs/1606.08415).
153
+
154
+ Parameters:
155
+ dim_in (`int`): The number of channels in the input.
156
+ dim_out (`int`): The number of channels in the output.
157
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
158
+ """
159
+
160
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
161
+ super().__init__()
162
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = self.proj(x)
166
+ return x * torch.sigmoid(1.702 * x)
167
+
168
+
169
+ class LinearActivation(nn.Module):
170
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
171
+ super().__init__()
172
+
173
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
174
+ self.activation = get_activation(activation)
175
+
176
+ def forward(self, hidden_states):
177
+ hidden_states = self.proj(hidden_states)
178
+ return self.activation(hidden_states)
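A short sketch exercising the activation helpers defined above (dims and batch shape are arbitrary; the import path assumes the vendored tree resolves as `diffusers`):

import torch
from diffusers.models.activations import GEGLU, SwiGLU, get_activation  # path assumed

x = torch.randn(2, 16, 64)

silu = get_activation("silu")            # returns nn.SiLU()
geglu = GEGLU(dim_in=64, dim_out=128)    # projects to 2 * dim_out, then gates down to dim_out
swiglu = SwiGLU(dim_in=64, dim_out=128)  # same gating, but with SiLU instead of GELU

print(silu(x).shape)    # torch.Size([2, 16, 64])
print(geglu(x).shape)   # torch.Size([2, 16, 128])
print(swiglu(x).shape)  # torch.Size([2, 16, 128])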
icedit/diffusers/models/adapter.py ADDED
@@ -0,0 +1,584 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Callable, List, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import logging
22
+ from .modeling_utils import ModelMixin
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class MultiAdapter(ModelMixin):
29
+ r"""
30
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
31
+ user-assigned weighting.
32
+
33
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
34
+ or saving.
35
+
36
+ Args:
37
+ adapters (`List[T2IAdapter]`):
38
+ A list of `T2IAdapter` model instances.
39
+ """
40
+
41
+ def __init__(self, adapters: List["T2IAdapter"]):
42
+ super(MultiAdapter, self).__init__()
43
+
44
+ self.num_adapter = len(adapters)
45
+ self.adapters = nn.ModuleList(adapters)
46
+
47
+ if len(adapters) == 0:
48
+ raise ValueError("Expecting at least one adapter")
49
+
50
+ if len(adapters) == 1:
51
+ raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
52
+
53
+ # The outputs from each adapter are added together with a weight.
54
+ # This means that the change in dimensions from downsampling must
55
+ # be the same for all adapters. Inductively, it also means the
56
+ # downscale_factor and total_downscale_factor must be the same for all
57
+ # adapters.
58
+ first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
59
+ first_adapter_downscale_factor = adapters[0].downscale_factor
60
+ for idx in range(1, len(adapters)):
61
+ if (
62
+ adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
63
+ or adapters[idx].downscale_factor != first_adapter_downscale_factor
64
+ ):
65
+ raise ValueError(
66
+ f"Expecting all adapters to have the same downscaling behavior, but got:\n"
67
+ f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
68
+ f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
69
+ f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
70
+ f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
71
+ )
72
+
73
+ self.total_downscale_factor = first_adapter_total_downscale_factor
74
+ self.downscale_factor = first_adapter_downscale_factor
75
+
76
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
77
+ r"""
78
+ Args:
79
+ xs (`torch.Tensor`):
80
+ A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
81
+ models, concatenated along dimension 1 (the channel dimension). The `channel` dimension should be equal to
82
+ `num_adapter` * the number of channels per image.
83
+
84
+ adapter_weights (`List[float]`, *optional*, defaults to None):
85
+ A list of floats representing the weights which will be multiplied by each adapter's output before
86
+ summing them together. If `None`, equal weights will be used for all adapters.
87
+ """
88
+ if adapter_weights is None:
89
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
90
+ else:
91
+ adapter_weights = torch.tensor(adapter_weights)
92
+
93
+ accume_state = None
94
+ for x, w, adapter in zip(xs, adapter_weights, self.adapters):
95
+ features = adapter(x)
96
+ if accume_state is None:
97
+ accume_state = features
98
+ for i in range(len(accume_state)):
99
+ accume_state[i] = w * accume_state[i]
100
+ else:
101
+ for i in range(len(features)):
102
+ accume_state[i] += w * features[i]
103
+ return accume_state
104
+
105
+ def save_pretrained(
106
+ self,
107
+ save_directory: Union[str, os.PathLike],
108
+ is_main_process: bool = True,
109
+ save_function: Callable = None,
110
+ safe_serialization: bool = True,
111
+ variant: Optional[str] = None,
112
+ ):
113
+ """
114
+ Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
115
+ [`~models.adapter.MultiAdapter.from_pretrained`] class method.
116
+
117
+ Args:
118
+ save_directory (`str` or `os.PathLike`):
119
+ The directory where the model will be saved. If the directory does not exist, it will be created.
120
+ is_main_process (`bool`, *optional*, defaults to `True`):
122
+ Whether the current process is the main process. Useful for distributed training (e.g., on
123
+ TPUs) where this function needs to be called on all processes; in that case, set `is_main_process=True` only
124
+ for the main process to avoid race conditions.
124
+ save_function (`Callable`):
125
+ Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
126
+ `torch.save` with another method. Can also be configured using the `DIFFUSERS_SAVE_MODE` environment
127
+ variable.
128
+ safe_serialization (`bool`, *optional*, defaults to `True`):
129
+ If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
130
+ variant (`str`, *optional*):
131
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
132
+ """
133
+ idx = 0
134
+ model_path_to_save = save_directory
135
+ for adapter in self.adapters:
136
+ adapter.save_pretrained(
137
+ model_path_to_save,
138
+ is_main_process=is_main_process,
139
+ save_function=save_function,
140
+ safe_serialization=safe_serialization,
141
+ variant=variant,
142
+ )
143
+
144
+ idx += 1
145
+ model_path_to_save = model_path_to_save + f"_{idx}"
146
+
147
+ @classmethod
148
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
149
+ r"""
150
+ Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
151
+
152
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
153
+ the model, set it back to training mode using `model.train()`.
154
+
155
+ Warnings:
156
+ *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
157
+ with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
158
+ from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
159
+
160
+ Args:
161
+ pretrained_model_path (`os.PathLike`):
162
+ A path to a *directory* containing model weights saved using
163
+ [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
164
+ torch_dtype (`str` or `torch.dtype`, *optional*):
165
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
166
+ will be automatically derived from the model's weights.
167
+ output_loading_info(`bool`, *optional*, defaults to `False`):
168
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
169
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
170
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
171
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
172
+ same device.
173
+
174
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
175
+ more information about each option see [designing a device
176
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
177
+ max_memory (`Dict`, *optional*):
178
+ A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
179
+ available for each GPU and the available CPU RAM if unset.
180
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
181
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
182
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
183
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
184
+ setting this argument to `True` will raise an error.
185
+ variant (`str`, *optional*):
186
+ If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
187
+ be ignored when using `from_flax`.
188
+ use_safetensors (`bool`, *optional*, defaults to `None`):
189
+ If `None`, the `safetensors` weights will be downloaded if available **and** if the `safetensors` library is
190
+ installed. If `True`, the model will be forcibly loaded from `safetensors` weights. If `False`,
191
+ `safetensors` is not used.
192
+ """
193
+ idx = 0
194
+ adapters = []
195
+
196
+ # load adapter and append to list until no adapter directory exists anymore
197
+ # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
198
+ # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
199
+ model_path_to_load = pretrained_model_path
200
+ while os.path.isdir(model_path_to_load):
201
+ adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
202
+ adapters.append(adapter)
203
+
204
+ idx += 1
205
+ model_path_to_load = pretrained_model_path + f"_{idx}"
206
+
207
+ logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
208
+
209
+ if len(adapters) == 0:
210
+ raise ValueError(
211
+ f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
212
+ )
213
+
214
+ return cls(adapters)
215
+
216
+
217
+ class T2IAdapter(ModelMixin, ConfigMixin):
218
+ r"""
219
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
220
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
221
+ architecture follows the original implementation of
222
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
223
+ and
224
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
225
+
226
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
227
+ downloading or saving.
228
+
229
+ Args:
230
+ in_channels (`int`, *optional*, defaults to `3`):
231
+ The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a grayscale
232
+ image.
233
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
234
+ The number of channels in each downsample block's output hidden state. The `len(channels)`
235
+ determines the number of downsample blocks in the adapter.
236
+ num_res_blocks (`int`, *optional*, defaults to `2`):
237
+ Number of ResNet blocks in each downsample block.
238
+ downscale_factor (`int`, *optional*, defaults to `8`):
239
+ A factor that determines the total downscale factor of the Adapter.
240
+ adapter_type (`str`, *optional*, defaults to `full_adapter`):
241
+ Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
242
+ """
243
+
244
+ @register_to_config
245
+ def __init__(
246
+ self,
247
+ in_channels: int = 3,
248
+ channels: List[int] = [320, 640, 1280, 1280],
249
+ num_res_blocks: int = 2,
250
+ downscale_factor: int = 8,
251
+ adapter_type: str = "full_adapter",
252
+ ):
253
+ super().__init__()
254
+
255
+ if adapter_type == "full_adapter":
256
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
257
+ elif adapter_type == "full_adapter_xl":
258
+ self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
259
+ elif adapter_type == "light_adapter":
260
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
261
+ else:
262
+ raise ValueError(
263
+ f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
264
+ "'full_adapter_xl' or 'light_adapter'."
265
+ )
266
+
267
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
268
+ r"""
269
+ This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
270
+ each representing information extracted at a different scale from the input. The length of the list is
271
+ determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
272
+ `num_res_blocks` parameters during initialization.
273
+ """
274
+ return self.adapter(x)
275
+
276
+ @property
277
+ def total_downscale_factor(self):
278
+ return self.adapter.total_downscale_factor
279
+
280
+ @property
281
+ def downscale_factor(self):
282
+ """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
283
+ not evenly divisible by the downscale_factor then an exception will be raised.
284
+ """
285
+ return self.adapter.unshuffle.downscale_factor
286
+
287
+
288
+ # full adapter
289
+
290
+
291
+ class FullAdapter(nn.Module):
292
+ r"""
293
+ See [`T2IAdapter`] for more information.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ in_channels: int = 3,
299
+ channels: List[int] = [320, 640, 1280, 1280],
300
+ num_res_blocks: int = 2,
301
+ downscale_factor: int = 8,
302
+ ):
303
+ super().__init__()
304
+
305
+ in_channels = in_channels * downscale_factor**2
306
+
307
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
308
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
309
+
310
+ self.body = nn.ModuleList(
311
+ [
312
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
313
+ *[
314
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
315
+ for i in range(1, len(channels))
316
+ ],
317
+ ]
318
+ )
319
+
320
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
321
+
322
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
323
+ r"""
324
+ This method processes the input tensor `x` through the FullAdapter model and performs operations including
325
+ pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
326
+ capturing information at a different stage of processing within the FullAdapter model. The number of feature
327
+ tensors in the list is determined by the number of downsample blocks specified during initialization.
328
+ """
329
+ x = self.unshuffle(x)
330
+ x = self.conv_in(x)
331
+
332
+ features = []
333
+
334
+ for block in self.body:
335
+ x = block(x)
336
+ features.append(x)
337
+
338
+ return features
339
+
340
+
341
+ class FullAdapterXL(nn.Module):
342
+ r"""
343
+ See [`T2IAdapter`] for more information.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels: int = 3,
349
+ channels: List[int] = [320, 640, 1280, 1280],
350
+ num_res_blocks: int = 2,
351
+ downscale_factor: int = 16,
352
+ ):
353
+ super().__init__()
354
+
355
+ in_channels = in_channels * downscale_factor**2
356
+
357
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
358
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
359
+
360
+ self.body = []
361
+ # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
362
+ for i in range(len(channels)):
363
+ if i == 1:
364
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
365
+ elif i == 2:
366
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
367
+ else:
368
+ self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
369
+
370
+ self.body = nn.ModuleList(self.body)
371
+ # XL has only one downsampling AdapterBlock.
372
+ self.total_downscale_factor = downscale_factor * 2
373
+
374
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
375
+ r"""
376
+ This method takes the tensor x as input and processes it through the FullAdapterXL model: it unshuffles the
378
+ pixels, applies the input convolution, and appends each block's output to a list of feature tensors.
378
+ """
379
+ x = self.unshuffle(x)
380
+ x = self.conv_in(x)
381
+
382
+ features = []
383
+
384
+ for block in self.body:
385
+ x = block(x)
386
+ features.append(x)
387
+
388
+ return features
389
+
390
+
391
+ class AdapterBlock(nn.Module):
392
+ r"""
393
+ An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
394
+ `FullAdapterXL` models.
395
+
396
+ Args:
397
+ in_channels (`int`):
398
+ Number of channels of AdapterBlock's input.
399
+ out_channels (`int`):
400
+ Number of channels of AdapterBlock's output.
401
+ num_res_blocks (`int`):
402
+ Number of ResNet blocks in the AdapterBlock.
403
+ down (`bool`, *optional*, defaults to `False`):
404
+ If `True`, perform downsampling on AdapterBlock's input.
405
+ """
406
+
407
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
408
+ super().__init__()
409
+
410
+ self.downsample = None
411
+ if down:
412
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
413
+
414
+ self.in_conv = None
415
+ if in_channels != out_channels:
416
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
417
+
418
+ self.resnets = nn.Sequential(
419
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ r"""
424
+ This method takes tensor x as input and applies downsampling and an input convolution if the
426
+ self.downsample and self.in_conv attributes of the AdapterBlock are set. It then applies a series of
427
+ residual blocks to the tensor.
427
+ """
428
+ if self.downsample is not None:
429
+ x = self.downsample(x)
430
+
431
+ if self.in_conv is not None:
432
+ x = self.in_conv(x)
433
+
434
+ x = self.resnets(x)
435
+
436
+ return x
437
+
438
+
439
+ class AdapterResnetBlock(nn.Module):
440
+ r"""
441
+ An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
442
+
443
+ Args:
444
+ channels (`int`):
445
+ Number of channels of AdapterResnetBlock's input and output.
446
+ """
447
+
448
+ def __init__(self, channels: int):
449
+ super().__init__()
450
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
451
+ self.act = nn.ReLU()
452
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
453
+
454
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
455
+ r"""
456
+ This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
457
+ layer on the input tensor. It returns addition with the input tensor.
458
+ """
459
+
460
+ h = self.act(self.block1(x))
461
+ h = self.block2(h)
462
+
463
+ return h + x
464
+
465
+
466
+ # light adapter
467
+
468
+
469
+ class LightAdapter(nn.Module):
470
+ r"""
471
+ See [`T2IAdapter`] for more information.
472
+ """
473
+
474
+ def __init__(
475
+ self,
476
+ in_channels: int = 3,
477
+ channels: List[int] = [320, 640, 1280],
478
+ num_res_blocks: int = 4,
479
+ downscale_factor: int = 8,
480
+ ):
481
+ super().__init__()
482
+
483
+ in_channels = in_channels * downscale_factor**2
484
+
485
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
486
+
487
+ self.body = nn.ModuleList(
488
+ [
489
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
490
+ *[
491
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
492
+ for i in range(len(channels) - 1)
493
+ ],
494
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
495
+ ]
496
+ )
497
+
498
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
499
+
500
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
501
+ r"""
502
+ This method pixel-unshuffles the input tensor x, runs it through the adapter blocks, and appends each block's output to a list of feature tensors. Each
503
+ feature tensor corresponds to a different level of processing within the LightAdapter.
504
+ """
505
+ x = self.unshuffle(x)
506
+
507
+ features = []
508
+
509
+ for block in self.body:
510
+ x = block(x)
511
+ features.append(x)
512
+
513
+ return features
514
+
515
+
516
+ class LightAdapterBlock(nn.Module):
517
+ r"""
518
+ A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
519
+ `LightAdapter` model.
520
+
521
+ Args:
522
+ in_channels (`int`):
523
+ Number of channels of LightAdapterBlock's input.
524
+ out_channels (`int`):
525
+ Number of channels of LightAdapterBlock's output.
526
+ num_res_blocks (`int`):
527
+ Number of LightAdapterResnetBlocks in the LightAdapterBlock.
528
+ down (`bool`, *optional*, defaults to `False`):
529
+ If `True`, perform downsampling on LightAdapterBlock's input.
530
+ """
531
+
532
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
533
+ super().__init__()
534
+ mid_channels = out_channels // 4
535
+
536
+ self.downsample = None
537
+ if down:
538
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
539
+
540
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
541
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
542
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
543
+
544
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
545
+ r"""
546
+ This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
547
+ layer, a sequence of residual blocks, and out convolutional layer.
548
+ """
549
+ if self.downsample is not None:
550
+ x = self.downsample(x)
551
+
552
+ x = self.in_conv(x)
553
+ x = self.resnets(x)
554
+ x = self.out_conv(x)
555
+
556
+ return x
557
+
558
+
559
+ class LightAdapterResnetBlock(nn.Module):
560
+ """
561
+ A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
562
+ architecture than `AdapterResnetBlock`.
563
+
564
+ Args:
565
+ channels (`int`):
566
+ Number of channels of LightAdapterResnetBlock's input and output.
567
+ """
568
+
569
+ def __init__(self, channels: int):
570
+ super().__init__()
571
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
572
+ self.act = nn.ReLU()
573
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
574
+
575
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
576
+ r"""
577
+ This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
578
+ another convolutional layer and adds it to input tensor.
579
+ """
580
+
581
+ h = self.act(self.block1(x))
582
+ h = self.block2(h)
583
+
584
+ return h + x
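A minimal sketch of the `T2IAdapter` defined above, showing the multi-scale feature pyramid its forward pass returns for a control image. The 512x512 input is illustrative (height and width only need to be divisible by `total_downscale_factor`), and the import path assumes the vendored tree resolves as `diffusers`.

import torch
from diffusers.models.adapter import T2IAdapter  # path assumed for the vendored copy

adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    downscale_factor=8,          # total_downscale_factor = 8 * 2**(len(channels) - 1) = 64
)

control = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    features = adapter(control)

print([tuple(f.shape) for f in features])
# [(1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)]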
icedit/diffusers/models/attention.py ADDED
@@ -0,0 +1,1252 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate, logging
21
+ from ..utils.torch_utils import maybe_allow_in_graph
22
+ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
23
+ from .attention_processor import Attention, JointAttnProcessor2_0
24
+ from .embeddings import SinusoidalPositionalEmbedding
25
+ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ num_attention_heads: int,
107
+ attention_head_dim: int,
108
+ context_pre_only: bool = False,
109
+ qk_norm: Optional[str] = None,
110
+ use_dual_attention: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.use_dual_attention = use_dual_attention
115
+ self.context_pre_only = context_pre_only
116
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
+
118
+ if use_dual_attention:
119
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
120
+ else:
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ if context_norm_type == "ada_norm_continous":
124
+ self.norm1_context = AdaLayerNormContinuous(
125
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
+ )
127
+ elif context_norm_type == "ada_norm_zero":
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+ else:
130
+ raise ValueError(
131
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
132
+ )
133
+
134
+ if hasattr(F, "scaled_dot_product_attention"):
135
+ processor = JointAttnProcessor2_0()
136
+ else:
137
+ raise ValueError(
138
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
+ )
140
+
141
+ self.attn = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=None,
144
+ added_kv_proj_dim=dim,
145
+ dim_head=attention_head_dim,
146
+ heads=num_attention_heads,
147
+ out_dim=dim,
148
+ context_pre_only=context_pre_only,
149
+ bias=True,
150
+ processor=processor,
151
+ qk_norm=qk_norm,
152
+ eps=1e-6,
153
+ )
154
+
155
+ if use_dual_attention:
156
+ self.attn2 = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ out_dim=dim,
162
+ bias=True,
163
+ processor=processor,
164
+ qk_norm=qk_norm,
165
+ eps=1e-6,
166
+ )
167
+ else:
168
+ self.attn2 = None
169
+
170
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
+
173
+ if not context_pre_only:
174
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
+ else:
177
+ self.norm2_context = None
178
+ self.ff_context = None
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
+ # Sets chunk feed-forward
187
+ self._chunk_size = chunk_size
188
+ self._chunk_dim = dim
189
+
190
+ def forward(
191
+ self,
192
+ hidden_states: torch.FloatTensor,
193
+ encoder_hidden_states: torch.FloatTensor,
194
+ temb: torch.FloatTensor,
195
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
196
+ ):
197
+ joint_attention_kwargs = joint_attention_kwargs or {}
198
+ if self.use_dual_attention:
199
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
200
+ hidden_states, emb=temb
201
+ )
202
+ else:
203
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
204
+
205
+ if self.context_pre_only:
206
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
207
+ else:
208
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
209
+ encoder_hidden_states, emb=temb
210
+ )
211
+
212
+ # Attention.
213
+ attn_output, context_attn_output = self.attn(
214
+ hidden_states=norm_hidden_states,
215
+ encoder_hidden_states=norm_encoder_hidden_states,
216
+ **joint_attention_kwargs,
217
+ )
218
+
219
+ # Process attention outputs for the `hidden_states`.
220
+ attn_output = gate_msa.unsqueeze(1) * attn_output
221
+ hidden_states = hidden_states + attn_output
222
+
223
+ if self.use_dual_attention:
224
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
225
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
226
+ hidden_states = hidden_states + attn_output2
227
+
228
+ norm_hidden_states = self.norm2(hidden_states)
229
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
230
+ if self._chunk_size is not None:
231
+ # "feed_forward_chunk_size" can be used to save memory
232
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
233
+ else:
234
+ ff_output = self.ff(norm_hidden_states)
235
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
236
+
237
+ hidden_states = hidden_states + ff_output
238
+
239
+ # Process attention outputs for the `encoder_hidden_states`.
240
+ if self.context_pre_only:
241
+ encoder_hidden_states = None
242
+ else:
243
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
244
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
245
+
246
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
247
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
248
+ if self._chunk_size is not None:
249
+ # "feed_forward_chunk_size" can be used to save memory
250
+ context_ff_output = _chunked_feed_forward(
251
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
252
+ )
253
+ else:
254
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
255
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
256
+
257
+ return encoder_hidden_states, hidden_states
258
+
259
+
260
+ @maybe_allow_in_graph
261
+ class BasicTransformerBlock(nn.Module):
262
+ r"""
263
+ A basic Transformer block.
264
+
265
+ Parameters:
266
+ dim (`int`): The number of channels in the input and output.
267
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
268
+ attention_head_dim (`int`): The number of channels in each head.
269
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
270
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
271
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
272
+ num_embeds_ada_norm (:
273
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
274
+ attention_bias (:
275
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
276
+ only_cross_attention (`bool`, *optional*):
277
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
278
+ double_self_attention (`bool`, *optional*):
279
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
280
+ upcast_attention (`bool`, *optional*):
281
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
282
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
283
+ Whether to use learnable elementwise affine parameters for normalization.
284
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
285
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
286
+ final_dropout (`bool` *optional*, defaults to False):
287
+ Whether to apply a final dropout after the last feed-forward layer.
288
+ attention_type (`str`, *optional*, defaults to `"default"`):
289
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
290
+ positional_embeddings (`str`, *optional*, defaults to `None`):
291
+ The type of positional embeddings to apply.
292
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
293
+ The maximum number of positional embeddings to apply.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ dim: int,
299
+ num_attention_heads: int,
300
+ attention_head_dim: int,
301
+ dropout=0.0,
302
+ cross_attention_dim: Optional[int] = None,
303
+ activation_fn: str = "geglu",
304
+ num_embeds_ada_norm: Optional[int] = None,
305
+ attention_bias: bool = False,
306
+ only_cross_attention: bool = False,
307
+ double_self_attention: bool = False,
308
+ upcast_attention: bool = False,
309
+ norm_elementwise_affine: bool = True,
310
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
311
+ norm_eps: float = 1e-5,
312
+ final_dropout: bool = False,
313
+ attention_type: str = "default",
314
+ positional_embeddings: Optional[str] = None,
315
+ num_positional_embeddings: Optional[int] = None,
316
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
317
+ ada_norm_bias: Optional[int] = None,
318
+ ff_inner_dim: Optional[int] = None,
319
+ ff_bias: bool = True,
320
+ attention_out_bias: bool = True,
321
+ ):
322
+ super().__init__()
323
+ self.dim = dim
324
+ self.num_attention_heads = num_attention_heads
325
+ self.attention_head_dim = attention_head_dim
326
+ self.dropout = dropout
327
+ self.cross_attention_dim = cross_attention_dim
328
+ self.activation_fn = activation_fn
329
+ self.attention_bias = attention_bias
330
+ self.double_self_attention = double_self_attention
331
+ self.norm_elementwise_affine = norm_elementwise_affine
332
+ self.positional_embeddings = positional_embeddings
333
+ self.num_positional_embeddings = num_positional_embeddings
334
+ self.only_cross_attention = only_cross_attention
335
+
336
+ # We keep these boolean flags for backward-compatibility.
337
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
338
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
339
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
340
+ self.use_layer_norm = norm_type == "layer_norm"
341
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
342
+
343
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
344
+ raise ValueError(
345
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
346
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
347
+ )
348
+
349
+ self.norm_type = norm_type
350
+ self.num_embeds_ada_norm = num_embeds_ada_norm
351
+
352
+ if positional_embeddings and (num_positional_embeddings is None):
353
+ raise ValueError(
354
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
355
+ )
356
+
357
+ if positional_embeddings == "sinusoidal":
358
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
359
+ else:
360
+ self.pos_embed = None
361
+
362
+ # Define 3 blocks. Each block has its own normalization layer.
363
+ # 1. Self-Attn
364
+ if norm_type == "ada_norm":
365
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
366
+ elif norm_type == "ada_norm_zero":
367
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
368
+ elif norm_type == "ada_norm_continuous":
369
+ self.norm1 = AdaLayerNormContinuous(
370
+ dim,
371
+ ada_norm_continous_conditioning_embedding_dim,
372
+ norm_elementwise_affine,
373
+ norm_eps,
374
+ ada_norm_bias,
375
+ "rms_norm",
376
+ )
377
+ else:
378
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
379
+
380
+ self.attn1 = Attention(
381
+ query_dim=dim,
382
+ heads=num_attention_heads,
383
+ dim_head=attention_head_dim,
384
+ dropout=dropout,
385
+ bias=attention_bias,
386
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
387
+ upcast_attention=upcast_attention,
388
+ out_bias=attention_out_bias,
389
+ )
390
+
391
+ # 2. Cross-Attn
392
+ if cross_attention_dim is not None or double_self_attention:
393
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
394
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
395
+ # the second cross attention block.
396
+ if norm_type == "ada_norm":
397
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
398
+ elif norm_type == "ada_norm_continuous":
399
+ self.norm2 = AdaLayerNormContinuous(
400
+ dim,
401
+ ada_norm_continous_conditioning_embedding_dim,
402
+ norm_elementwise_affine,
403
+ norm_eps,
404
+ ada_norm_bias,
405
+ "rms_norm",
406
+ )
407
+ else:
408
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
409
+
410
+ self.attn2 = Attention(
411
+ query_dim=dim,
412
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
413
+ heads=num_attention_heads,
414
+ dim_head=attention_head_dim,
415
+ dropout=dropout,
416
+ bias=attention_bias,
417
+ upcast_attention=upcast_attention,
418
+ out_bias=attention_out_bias,
419
+ ) # is self-attn if encoder_hidden_states is none
420
+ else:
421
+ if norm_type == "ada_norm_single": # For Latte
422
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
423
+ else:
424
+ self.norm2 = None
425
+ self.attn2 = None
426
+
427
+ # 3. Feed-forward
428
+ if norm_type == "ada_norm_continuous":
429
+ self.norm3 = AdaLayerNormContinuous(
430
+ dim,
431
+ ada_norm_continous_conditioning_embedding_dim,
432
+ norm_elementwise_affine,
433
+ norm_eps,
434
+ ada_norm_bias,
435
+ "layer_norm",
436
+ )
437
+
438
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
439
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
440
+ elif norm_type == "layer_norm_i2vgen":
441
+ self.norm3 = None
442
+
443
+ self.ff = FeedForward(
444
+ dim,
445
+ dropout=dropout,
446
+ activation_fn=activation_fn,
447
+ final_dropout=final_dropout,
448
+ inner_dim=ff_inner_dim,
449
+ bias=ff_bias,
450
+ )
451
+
452
+ # 4. Fuser
453
+ if attention_type == "gated" or attention_type == "gated-text-image":
454
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
455
+
456
+ # 5. Scale-shift for PixArt-Alpha.
457
+ if norm_type == "ada_norm_single":
458
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
459
+
460
+ # let chunk size default to None
461
+ self._chunk_size = None
462
+ self._chunk_dim = 0
463
+
464
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
465
+ # Sets chunk feed-forward
466
+ self._chunk_size = chunk_size
467
+ self._chunk_dim = dim
468
+
469
+ def forward(
470
+ self,
471
+ hidden_states: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ encoder_hidden_states: Optional[torch.Tensor] = None,
474
+ encoder_attention_mask: Optional[torch.Tensor] = None,
475
+ timestep: Optional[torch.LongTensor] = None,
476
+ cross_attention_kwargs: Dict[str, Any] = None,
477
+ class_labels: Optional[torch.LongTensor] = None,
478
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
479
+ ) -> torch.Tensor:
480
+ if cross_attention_kwargs is not None:
481
+ if cross_attention_kwargs.get("scale", None) is not None:
482
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
483
+
484
+ # Notice that normalization is always applied before the real computation in the following blocks.
485
+ # 0. Self-Attention
486
+ batch_size = hidden_states.shape[0]
487
+
488
+ if self.norm_type == "ada_norm":
489
+ norm_hidden_states = self.norm1(hidden_states, timestep)
490
+ elif self.norm_type == "ada_norm_zero":
491
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
492
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
493
+ )
494
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
495
+ norm_hidden_states = self.norm1(hidden_states)
496
+ elif self.norm_type == "ada_norm_continuous":
497
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
498
+ elif self.norm_type == "ada_norm_single":
499
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
500
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
501
+ ).chunk(6, dim=1)
502
+ norm_hidden_states = self.norm1(hidden_states)
503
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
504
+ else:
505
+ raise ValueError("Incorrect norm used")
506
+
507
+ if self.pos_embed is not None:
508
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
509
+
510
+ # 1. Prepare GLIGEN inputs
511
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
512
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
513
+
514
+ attn_output = self.attn1(
515
+ norm_hidden_states,
516
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
517
+ attention_mask=attention_mask,
518
+ **cross_attention_kwargs,
519
+ )
520
+
521
+ if self.norm_type == "ada_norm_zero":
522
+ attn_output = gate_msa.unsqueeze(1) * attn_output
523
+ elif self.norm_type == "ada_norm_single":
524
+ attn_output = gate_msa * attn_output
525
+
526
+ hidden_states = attn_output + hidden_states
527
+ if hidden_states.ndim == 4:
528
+ hidden_states = hidden_states.squeeze(1)
529
+
530
+ # 1.2 GLIGEN Control
531
+ if gligen_kwargs is not None:
532
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
533
+
534
+ # 3. Cross-Attention
535
+ if self.attn2 is not None:
536
+ if self.norm_type == "ada_norm":
537
+ norm_hidden_states = self.norm2(hidden_states, timestep)
538
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
539
+ norm_hidden_states = self.norm2(hidden_states)
540
+ elif self.norm_type == "ada_norm_single":
541
+ # For PixArt norm2 isn't applied here:
542
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
543
+ norm_hidden_states = hidden_states
544
+ elif self.norm_type == "ada_norm_continuous":
545
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
546
+ else:
547
+ raise ValueError("Incorrect norm")
548
+
549
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
550
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
551
+
552
+ attn_output = self.attn2(
553
+ norm_hidden_states,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ attention_mask=encoder_attention_mask,
556
+ **cross_attention_kwargs,
557
+ )
558
+ hidden_states = attn_output + hidden_states
559
+
560
+ # 4. Feed-forward
561
+ # i2vgen doesn't have this norm 🤷‍♂️
562
+ if self.norm_type == "ada_norm_continuous":
563
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
564
+ elif not self.norm_type == "ada_norm_single":
565
+ norm_hidden_states = self.norm3(hidden_states)
566
+
567
+ if self.norm_type == "ada_norm_zero":
568
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
569
+
570
+ if self.norm_type == "ada_norm_single":
571
+ norm_hidden_states = self.norm2(hidden_states)
572
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
573
+
574
+ if self._chunk_size is not None:
575
+ # "feed_forward_chunk_size" can be used to save memory
576
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
577
+ else:
578
+ ff_output = self.ff(norm_hidden_states)
579
+
580
+ if self.norm_type == "ada_norm_zero":
581
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
582
+ elif self.norm_type == "ada_norm_single":
583
+ ff_output = gate_mlp * ff_output
584
+
585
+ hidden_states = ff_output + hidden_states
586
+ if hidden_states.ndim == 4:
587
+ hidden_states = hidden_states.squeeze(1)
588
+
589
+ return hidden_states
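
A minimal usage sketch for the block above (illustrative only: the shapes, the default "layer_norm" configuration with no cross-attention, and importing the vendored copy as `diffusers` are assumptions, not part of this diff):

    import torch
    from diffusers.models.attention import BasicTransformerBlock

    # Self-attention-only configuration: norm_type defaults to "layer_norm" and
    # cross_attention_dim is left unset, so attn2 is never created.
    block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)

    hidden_states = torch.randn(2, 64, 320)   # (batch, tokens, channels)
    out = block(hidden_states)                # norm1 -> attn1 -> norm3 -> ff, with residual adds

    # Optionally chunk the feed-forward over the token dimension to lower peak memory;
    # the chunked result matches the unchunked one up to floating-point noise.
    block.set_chunk_feed_forward(chunk_size=32, dim=1)
    out_chunked = block(hidden_states)
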
590
+
591
+
592
+ class LuminaFeedForward(nn.Module):
593
+ r"""
594
+ A feed-forward layer.
595
+
596
+ Parameters:
597
+ dim (`int`):
598
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
599
+ hidden representations.
600
+ inner_dim (`int`): The intermediate dimension of the feedforward layer.
601
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
602
+ of this value.
603
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
604
+ dimension. Defaults to None.
605
+ """
606
+
607
+ def __init__(
608
+ self,
609
+ dim: int,
610
+ inner_dim: int,
611
+ multiple_of: Optional[int] = 256,
612
+ ffn_dim_multiplier: Optional[float] = None,
613
+ ):
614
+ super().__init__()
615
+ inner_dim = int(2 * inner_dim / 3)
616
+ # custom hidden_size factor multiplier
617
+ if ffn_dim_multiplier is not None:
618
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
619
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
620
+
621
+ self.linear_1 = nn.Linear(
622
+ dim,
623
+ inner_dim,
624
+ bias=False,
625
+ )
626
+ self.linear_2 = nn.Linear(
627
+ inner_dim,
628
+ dim,
629
+ bias=False,
630
+ )
631
+ self.linear_3 = nn.Linear(
632
+ dim,
633
+ inner_dim,
634
+ bias=False,
635
+ )
636
+ self.silu = FP32SiLU()
637
+
638
+ def forward(self, x):
639
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
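
A quick, hypothetical check of the hidden-size arithmetic in `__init__` above (the numbers are made up for illustration):

    inner_dim, multiple_of = 4096, 256
    inner_dim = int(2 * inner_dim / 3)                                        # 2730
    inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)  # rounded up to 2816

so a requested inner dimension of 4096 ends up as a gated-projection width of 2816, the next multiple of 256 after the 2/3 scaling.
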
640
+
641
+
642
+ @maybe_allow_in_graph
643
+ class TemporalBasicTransformerBlock(nn.Module):
644
+ r"""
645
+ A basic Transformer block for video like data.
646
+
647
+ Parameters:
648
+ dim (`int`): The number of channels in the input and output.
649
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
650
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
651
+ attention_head_dim (`int`): The number of channels in each head.
652
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
653
+ """
654
+
655
+ def __init__(
656
+ self,
657
+ dim: int,
658
+ time_mix_inner_dim: int,
659
+ num_attention_heads: int,
660
+ attention_head_dim: int,
661
+ cross_attention_dim: Optional[int] = None,
662
+ ):
663
+ super().__init__()
664
+ self.is_res = dim == time_mix_inner_dim
665
+
666
+ self.norm_in = nn.LayerNorm(dim)
667
+
668
+ # Define 3 blocks. Each block has its own normalization layer.
669
+ # 1. Self-Attn
670
+ self.ff_in = FeedForward(
671
+ dim,
672
+ dim_out=time_mix_inner_dim,
673
+ activation_fn="geglu",
674
+ )
675
+
676
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
677
+ self.attn1 = Attention(
678
+ query_dim=time_mix_inner_dim,
679
+ heads=num_attention_heads,
680
+ dim_head=attention_head_dim,
681
+ cross_attention_dim=None,
682
+ )
683
+
684
+ # 2. Cross-Attn
685
+ if cross_attention_dim is not None:
686
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
687
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
688
+ # the second cross attention block.
689
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
690
+ self.attn2 = Attention(
691
+ query_dim=time_mix_inner_dim,
692
+ cross_attention_dim=cross_attention_dim,
693
+ heads=num_attention_heads,
694
+ dim_head=attention_head_dim,
695
+ ) # is self-attn if encoder_hidden_states is none
696
+ else:
697
+ self.norm2 = None
698
+ self.attn2 = None
699
+
700
+ # 3. Feed-forward
701
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
702
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
703
+
704
+ # let chunk size default to None
705
+ self._chunk_size = None
706
+ self._chunk_dim = None
707
+
708
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
709
+ # Sets chunk feed-forward
710
+ self._chunk_size = chunk_size
711
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
712
+ self._chunk_dim = 1
713
+
714
+ def forward(
715
+ self,
716
+ hidden_states: torch.Tensor,
717
+ num_frames: int,
718
+ encoder_hidden_states: Optional[torch.Tensor] = None,
719
+ ) -> torch.Tensor:
720
+ # Notice that normalization is always applied before the real computation in the following blocks.
721
+ # 0. Self-Attention
722
+ batch_size = hidden_states.shape[0]
723
+
724
+ batch_frames, seq_length, channels = hidden_states.shape
725
+ batch_size = batch_frames // num_frames
726
+
727
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
728
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
729
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
730
+
731
+ residual = hidden_states
732
+ hidden_states = self.norm_in(hidden_states)
733
+
734
+ if self._chunk_size is not None:
735
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
736
+ else:
737
+ hidden_states = self.ff_in(hidden_states)
738
+
739
+ if self.is_res:
740
+ hidden_states = hidden_states + residual
741
+
742
+ norm_hidden_states = self.norm1(hidden_states)
743
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
744
+ hidden_states = attn_output + hidden_states
745
+
746
+ # 3. Cross-Attention
747
+ if self.attn2 is not None:
748
+ norm_hidden_states = self.norm2(hidden_states)
749
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
750
+ hidden_states = attn_output + hidden_states
751
+
752
+ # 4. Feed-forward
753
+ norm_hidden_states = self.norm3(hidden_states)
754
+
755
+ if self._chunk_size is not None:
756
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
757
+ else:
758
+ ff_output = self.ff(norm_hidden_states)
759
+
760
+ if self.is_res:
761
+ hidden_states = ff_output + hidden_states
762
+ else:
763
+ hidden_states = ff_output
764
+
765
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
766
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
767
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
768
+
769
+ return hidden_states
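
A minimal sketch of how the temporal block above is driven (the shapes and import path are assumptions for illustration; choosing `dim == time_mix_inner_dim` keeps the residual path active):

    import torch
    from diffusers.models.attention import TemporalBasicTransformerBlock

    block = TemporalBasicTransformerBlock(dim=320, time_mix_inner_dim=320, num_attention_heads=8, attention_head_dim=40)

    # Video latents flattened as (batch * num_frames, tokens, channels).
    x = torch.randn(2 * 8, 64, 320)
    out = block(x, num_frames=8)   # attends across the 8 frames for each spatial token; same shape as x
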
770
+
771
+
772
+ class SkipFFTransformerBlock(nn.Module):
773
+ def __init__(
774
+ self,
775
+ dim: int,
776
+ num_attention_heads: int,
777
+ attention_head_dim: int,
778
+ kv_input_dim: int,
779
+ kv_input_dim_proj_use_bias: bool,
780
+ dropout=0.0,
781
+ cross_attention_dim: Optional[int] = None,
782
+ attention_bias: bool = False,
783
+ attention_out_bias: bool = True,
784
+ ):
785
+ super().__init__()
786
+ if kv_input_dim != dim:
787
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
788
+ else:
789
+ self.kv_mapper = None
790
+
791
+ self.norm1 = RMSNorm(dim, 1e-06)
792
+
793
+ self.attn1 = Attention(
794
+ query_dim=dim,
795
+ heads=num_attention_heads,
796
+ dim_head=attention_head_dim,
797
+ dropout=dropout,
798
+ bias=attention_bias,
799
+ cross_attention_dim=cross_attention_dim,
800
+ out_bias=attention_out_bias,
801
+ )
802
+
803
+ self.norm2 = RMSNorm(dim, 1e-06)
804
+
805
+ self.attn2 = Attention(
806
+ query_dim=dim,
807
+ cross_attention_dim=cross_attention_dim,
808
+ heads=num_attention_heads,
809
+ dim_head=attention_head_dim,
810
+ dropout=dropout,
811
+ bias=attention_bias,
812
+ out_bias=attention_out_bias,
813
+ )
814
+
815
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
816
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
817
+
818
+ if self.kv_mapper is not None:
819
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
820
+
821
+ norm_hidden_states = self.norm1(hidden_states)
822
+
823
+ attn_output = self.attn1(
824
+ norm_hidden_states,
825
+ encoder_hidden_states=encoder_hidden_states,
826
+ **cross_attention_kwargs,
827
+ )
828
+
829
+ hidden_states = attn_output + hidden_states
830
+
831
+ norm_hidden_states = self.norm2(hidden_states)
832
+
833
+ attn_output = self.attn2(
834
+ norm_hidden_states,
835
+ encoder_hidden_states=encoder_hidden_states,
836
+ **cross_attention_kwargs,
837
+ )
838
+
839
+ hidden_states = attn_output + hidden_states
840
+
841
+ return hidden_states
842
+
843
+
844
+ @maybe_allow_in_graph
845
+ class FreeNoiseTransformerBlock(nn.Module):
846
+ r"""
847
+ A FreeNoise Transformer block.
848
+
849
+ Parameters:
850
+ dim (`int`):
851
+ The number of channels in the input and output.
852
+ num_attention_heads (`int`):
853
+ The number of heads to use for multi-head attention.
854
+ attention_head_dim (`int`):
855
+ The number of channels in each head.
856
+ dropout (`float`, *optional*, defaults to 0.0):
857
+ The dropout probability to use.
858
+ cross_attention_dim (`int`, *optional*):
859
+ The size of the encoder_hidden_states vector for cross attention.
860
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
861
+ Activation function to be used in feed-forward.
862
+ num_embeds_ada_norm (`int`, *optional*):
863
+ The number of diffusion steps used during training. See `Transformer2DModel`.
864
+ attention_bias (`bool`, defaults to `False`):
865
+ Configure if the attentions should contain a bias parameter.
866
+ only_cross_attention (`bool`, defaults to `False`):
867
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
868
+ double_self_attention (`bool`, defaults to `False`):
869
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
870
+ upcast_attention (`bool`, defaults to `False`):
871
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
872
+ norm_elementwise_affine (`bool`, defaults to `True`):
873
+ Whether to use learnable elementwise affine parameters for normalization.
874
+ norm_type (`str`, defaults to `"layer_norm"`):
875
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
876
+ final_dropout (`bool`, defaults to `False`):
877
+ Whether to apply a final dropout after the last feed-forward layer.
878
+ attention_type (`str`, defaults to `"default"`):
879
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
880
+ positional_embeddings (`str`, *optional*):
881
+ The type of positional embeddings to apply.
882
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
883
+ The maximum number of positional embeddings to apply.
884
+ ff_inner_dim (`int`, *optional*):
885
+ Hidden dimension of feed-forward MLP.
886
+ ff_bias (`bool`, defaults to `True`):
887
+ Whether or not to use bias in feed-forward MLP.
888
+ attention_out_bias (`bool`, defaults to `True`):
889
+ Whether or not to use bias in attention output project layer.
890
+ context_length (`int`, defaults to `16`):
891
+ The maximum number of frames that the FreeNoise block processes at once.
892
+ context_stride (`int`, defaults to `4`):
893
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
894
+ weighting_scheme (`str`, defaults to `"pyramid"`):
895
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
896
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
897
+ used.
898
+ """
899
+
900
+ def __init__(
901
+ self,
902
+ dim: int,
903
+ num_attention_heads: int,
904
+ attention_head_dim: int,
905
+ dropout: float = 0.0,
906
+ cross_attention_dim: Optional[int] = None,
907
+ activation_fn: str = "geglu",
908
+ num_embeds_ada_norm: Optional[int] = None,
909
+ attention_bias: bool = False,
910
+ only_cross_attention: bool = False,
911
+ double_self_attention: bool = False,
912
+ upcast_attention: bool = False,
913
+ norm_elementwise_affine: bool = True,
914
+ norm_type: str = "layer_norm",
915
+ norm_eps: float = 1e-5,
916
+ final_dropout: bool = False,
917
+ positional_embeddings: Optional[str] = None,
918
+ num_positional_embeddings: Optional[int] = None,
919
+ ff_inner_dim: Optional[int] = None,
920
+ ff_bias: bool = True,
921
+ attention_out_bias: bool = True,
922
+ context_length: int = 16,
923
+ context_stride: int = 4,
924
+ weighting_scheme: str = "pyramid",
925
+ ):
926
+ super().__init__()
927
+ self.dim = dim
928
+ self.num_attention_heads = num_attention_heads
929
+ self.attention_head_dim = attention_head_dim
930
+ self.dropout = dropout
931
+ self.cross_attention_dim = cross_attention_dim
932
+ self.activation_fn = activation_fn
933
+ self.attention_bias = attention_bias
934
+ self.double_self_attention = double_self_attention
935
+ self.norm_elementwise_affine = norm_elementwise_affine
936
+ self.positional_embeddings = positional_embeddings
937
+ self.num_positional_embeddings = num_positional_embeddings
938
+ self.only_cross_attention = only_cross_attention
939
+
940
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
941
+
942
+ # We keep these boolean flags for backward-compatibility.
943
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
944
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
945
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
946
+ self.use_layer_norm = norm_type == "layer_norm"
947
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
948
+
949
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
950
+ raise ValueError(
951
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
952
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
953
+ )
954
+
955
+ self.norm_type = norm_type
956
+ self.num_embeds_ada_norm = num_embeds_ada_norm
957
+
958
+ if positional_embeddings and (num_positional_embeddings is None):
959
+ raise ValueError(
960
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
961
+ )
962
+
963
+ if positional_embeddings == "sinusoidal":
964
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
965
+ else:
966
+ self.pos_embed = None
967
+
968
+ # Define 3 blocks. Each block has its own normalization layer.
969
+ # 1. Self-Attn
970
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
971
+
972
+ self.attn1 = Attention(
973
+ query_dim=dim,
974
+ heads=num_attention_heads,
975
+ dim_head=attention_head_dim,
976
+ dropout=dropout,
977
+ bias=attention_bias,
978
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
979
+ upcast_attention=upcast_attention,
980
+ out_bias=attention_out_bias,
981
+ )
982
+
983
+ # 2. Cross-Attn
984
+ if cross_attention_dim is not None or double_self_attention:
985
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
986
+
987
+ self.attn2 = Attention(
988
+ query_dim=dim,
989
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
990
+ heads=num_attention_heads,
991
+ dim_head=attention_head_dim,
992
+ dropout=dropout,
993
+ bias=attention_bias,
994
+ upcast_attention=upcast_attention,
995
+ out_bias=attention_out_bias,
996
+ ) # is self-attn if encoder_hidden_states is none
997
+
998
+ # 3. Feed-forward
999
+ self.ff = FeedForward(
1000
+ dim,
1001
+ dropout=dropout,
1002
+ activation_fn=activation_fn,
1003
+ final_dropout=final_dropout,
1004
+ inner_dim=ff_inner_dim,
1005
+ bias=ff_bias,
1006
+ )
1007
+
1008
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1009
+
1010
+ # let chunk size default to None
1011
+ self._chunk_size = None
1012
+ self._chunk_dim = 0
1013
+
1014
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1015
+ frame_indices = []
1016
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1017
+ window_start = i
1018
+ window_end = min(num_frames, i + self.context_length)
1019
+ frame_indices.append((window_start, window_end))
1020
+ return frame_indices
1021
+
1022
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1023
+ if weighting_scheme == "flat":
1024
+ weights = [1.0] * num_frames
1025
+
1026
+ elif weighting_scheme == "pyramid":
1027
+ if num_frames % 2 == 0:
1028
+ # num_frames = 4 => [1, 2, 2, 1]
1029
+ mid = num_frames // 2
1030
+ weights = list(range(1, mid + 1))
1031
+ weights = weights + weights[::-1]
1032
+ else:
1033
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1034
+ mid = (num_frames + 1) // 2
1035
+ weights = list(range(1, mid))
1036
+ weights = weights + [mid] + weights[::-1]
1037
+
1038
+ elif weighting_scheme == "delayed_reverse_sawtooth":
1039
+ if num_frames % 2 == 0:
1040
+ # num_frames = 4 => [0.01, 2, 2, 1]
1041
+ mid = num_frames // 2
1042
+ weights = [0.01] * (mid - 1) + [mid]
1043
+ weights = weights + list(range(mid, 0, -1))
1044
+ else:
1045
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1046
+ mid = (num_frames + 1) // 2
1047
+ weights = [0.01] * mid
1048
+ weights = weights + list(range(mid, 0, -1))
1049
+ else:
1050
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1051
+
1052
+ return weights
1053
+
1054
+ def set_free_noise_properties(
1055
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1056
+ ) -> None:
1057
+ self.context_length = context_length
1058
+ self.context_stride = context_stride
1059
+ self.weighting_scheme = weighting_scheme
1060
+
1061
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1062
+ # Sets chunk feed-forward
1063
+ self._chunk_size = chunk_size
1064
+ self._chunk_dim = dim
1065
+
1066
+ def forward(
1067
+ self,
1068
+ hidden_states: torch.Tensor,
1069
+ attention_mask: Optional[torch.Tensor] = None,
1070
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1071
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1072
+ cross_attention_kwargs: Dict[str, Any] = None,
1073
+ *args,
1074
+ **kwargs,
1075
+ ) -> torch.Tensor:
1076
+ if cross_attention_kwargs is not None:
1077
+ if cross_attention_kwargs.get("scale", None) is not None:
1078
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1079
+
1080
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1081
+
1082
+ # hidden_states: [B x H x W, F, C]
1083
+ device = hidden_states.device
1084
+ dtype = hidden_states.dtype
1085
+
1086
+ num_frames = hidden_states.size(1)
1087
+ frame_indices = self._get_frame_indices(num_frames)
1088
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1089
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1090
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1091
+
1092
+ # Handle the out-of-bounds case when num_frames isn't perfectly covered by the strided windows.
1093
+ # For example, with num_frames=25, context_length=16, context_stride=4, the strided windows are
1094
+ # [(0, 16), (4, 20), (8, 24)], which stop short of the last frame, so (9, 25) is appended below.
1095
+ if not is_last_frame_batch_complete:
1096
+ if num_frames < self.context_length:
1097
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1098
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1099
+ frame_indices.append((num_frames - self.context_length, num_frames))
1100
+
1101
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1102
+ accumulated_values = torch.zeros_like(hidden_states)
1103
+
1104
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1105
+ # Slicing `num_times_accumulated` keeps `weights` the same length as the current window,
1106
+ # i.e. (frame_end - frame_start). This matters for cases like frame_indices=[(0, 16), (16, 20)],
1107
+ # e.g. when the user provides a video whose frame count (say 19) is not a multiple of `context_length`.
1108
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1109
+ weights *= frame_weights
1110
+
1111
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1112
+
1113
+ # Notice that normalization is always applied before the real computation in the following blocks.
1114
+ # 1. Self-Attention
1115
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1116
+
1117
+ if self.pos_embed is not None:
1118
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1119
+
1120
+ attn_output = self.attn1(
1121
+ norm_hidden_states,
1122
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1123
+ attention_mask=attention_mask,
1124
+ **cross_attention_kwargs,
1125
+ )
1126
+
1127
+ hidden_states_chunk = attn_output + hidden_states_chunk
1128
+ if hidden_states_chunk.ndim == 4:
1129
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1130
+
1131
+ # 2. Cross-Attention
1132
+ if self.attn2 is not None:
1133
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1134
+
1135
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1136
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1137
+
1138
+ attn_output = self.attn2(
1139
+ norm_hidden_states,
1140
+ encoder_hidden_states=encoder_hidden_states,
1141
+ attention_mask=encoder_attention_mask,
1142
+ **cross_attention_kwargs,
1143
+ )
1144
+ hidden_states_chunk = attn_output + hidden_states_chunk
1145
+
1146
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1147
+ accumulated_values[:, -last_frame_batch_length:] += (
1148
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1149
+ )
1150
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1151
+ else:
1152
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1153
+ num_times_accumulated[:, frame_start:frame_end] += weights
1154
+
1155
+ # TODO(aryan): Maybe this could be done in a better way.
1156
+ #
1157
+ # Previously, this was:
1158
+ # hidden_states = torch.where(
1159
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1160
+ # )
1161
+ #
1162
+ # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1163
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1164
+ # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1165
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1166
+ hidden_states = torch.cat(
1167
+ [
1168
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1169
+ for accumulated_split, num_times_split in zip(
1170
+ accumulated_values.split(self.context_length, dim=1),
1171
+ num_times_accumulated.split(self.context_length, dim=1),
1172
+ )
1173
+ ],
1174
+ dim=1,
1175
+ ).to(dtype)
1176
+
1177
+ # 3. Feed-forward
1178
+ norm_hidden_states = self.norm3(hidden_states)
1179
+
1180
+ if self._chunk_size is not None:
1181
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1182
+ else:
1183
+ ff_output = self.ff(norm_hidden_states)
1184
+
1185
+ hidden_states = ff_output + hidden_states
1186
+ if hidden_states.ndim == 4:
1187
+ hidden_states = hidden_states.squeeze(1)
1188
+
1189
+ return hidden_states
1190
+
1191
+
1192
+ class FeedForward(nn.Module):
1193
+ r"""
1194
+ A feed-forward layer.
1195
+
1196
+ Parameters:
1197
+ dim (`int`): The number of channels in the input.
1198
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1199
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1200
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1201
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1202
+ final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
1203
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1204
+ """
1205
+
1206
+ def __init__(
1207
+ self,
1208
+ dim: int,
1209
+ dim_out: Optional[int] = None,
1210
+ mult: int = 4,
1211
+ dropout: float = 0.0,
1212
+ activation_fn: str = "geglu",
1213
+ final_dropout: bool = False,
1214
+ inner_dim=None,
1215
+ bias: bool = True,
1216
+ ):
1217
+ super().__init__()
1218
+ if inner_dim is None:
1219
+ inner_dim = int(dim * mult)
1220
+ dim_out = dim_out if dim_out is not None else dim
1221
+
1222
+ if activation_fn == "gelu":
1223
+ act_fn = GELU(dim, inner_dim, bias=bias)
1224
+ if activation_fn == "gelu-approximate":
1225
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1226
+ elif activation_fn == "geglu":
1227
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1228
+ elif activation_fn == "geglu-approximate":
1229
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1230
+ elif activation_fn == "swiglu":
1231
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1232
+ elif activation_fn == "linear-silu":
1233
+ act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
1234
+
1235
+ self.net = nn.ModuleList([])
1236
+ # project in
1237
+ self.net.append(act_fn)
1238
+ # project dropout
1239
+ self.net.append(nn.Dropout(dropout))
1240
+ # project out
1241
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1242
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1243
+ if final_dropout:
1244
+ self.net.append(nn.Dropout(dropout))
1245
+
1246
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1247
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1248
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1249
+ deprecate("scale", "1.0.0", deprecation_message)
1250
+ for module in self.net:
1251
+ hidden_states = module(hidden_states)
1252
+ return hidden_states
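
A small sketch of the `FeedForward` module defined above (shapes and the import path are assumptions; "geglu" is the default activation):

    import torch
    from diffusers.models.attention import FeedForward

    ff = FeedForward(dim=320, mult=4, activation_fn="geglu")   # GEGLU projects 320 -> 2*1280, gates, then Linear(1280 -> 320)
    x = torch.randn(2, 64, 320)
    y = ff(x)                                                  # same shape as the input: (2, 64, 320)
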
icedit/diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,494 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+
23
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
+ """Multi-head dot product attention with a limited number of queries."""
25
+ num_kv, num_heads, k_features = key.shape[-3:]
26
+ v_features = value.shape[-1]
27
+ key_chunk_size = min(key_chunk_size, num_kv)
28
+ query = query / jnp.sqrt(k_features)
29
+
30
+ @functools.partial(jax.checkpoint, prevent_cse=False)
31
+ def summarize_chunk(query, key, value):
32
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
+
34
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
+ max_score = jax.lax.stop_gradient(max_score)
36
+ exp_weights = jnp.exp(attn_weights - max_score)
37
+
38
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
+ max_score = jnp.einsum("...qhk->...qh", max_score)
40
+
41
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
42
+
43
+ def chunk_scanner(chunk_idx):
44
+ # julienne key array
45
+ key_chunk = jax.lax.dynamic_slice(
46
+ operand=key,
47
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
+ )
50
+
51
+ # julienne value array
52
+ value_chunk = jax.lax.dynamic_slice(
53
+ operand=value,
54
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
+ )
57
+
58
+ return summarize_chunk(query, key_chunk, value_chunk)
59
+
60
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
+
62
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
+ max_diffs = jnp.exp(chunk_max - global_max)
64
+
65
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
+ chunk_weights *= max_diffs
67
+
68
+ all_values = chunk_values.sum(axis=0)
69
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
+
71
+ return all_values / all_weights
72
+
73
+
74
+ def jax_memory_efficient_attention(
75
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
+ ):
77
+ r"""
78
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
+ https://github.com/AminRezaei0x443/memory-efficient-attention
80
+
81
+ Args:
82
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
+ numerical precision for computation
87
+ query_chunk_size (`int`, *optional*, defaults to 1024):
88
+ chunk size used to split the query array; the value must divide query_length without remainder
89
+ key_chunk_size (`int`, *optional*, defaults to 4096):
90
+ chunk size used to split the key and value arrays; the value must divide key_value_length without remainder
91
+
92
+ Returns:
93
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
+ """
95
+ num_q, num_heads, q_features = query.shape[-3:]
96
+
97
+ def chunk_scanner(chunk_idx, _):
98
+ # julienne query array
99
+ query_chunk = jax.lax.dynamic_slice(
100
+ operand=query,
101
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
+ )
104
+
105
+ return (
106
+ chunk_idx + query_chunk_size, # scan carry: next chunk start index (unused)
107
+ _query_chunk_attention(
108
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
+ ),
110
+ )
111
+
112
+ _, res = jax.lax.scan(
113
+ f=chunk_scanner,
114
+ init=0,
115
+ xs=None,
116
+ length=math.ceil(num_q / query_chunk_size), # number of query chunks to scan over
117
+ )
118
+
119
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
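
A direct-call sketch for the chunked attention helper above (the shapes are hypothetical; `query_chunk_size` should divide the query length, and `key_chunk_size` is clamped to the key length internally):

    import jax.numpy as jnp
    from diffusers.models.attention_flax import jax_memory_efficient_attention

    q = jnp.zeros((2, 4096, 8, 64))   # (batch, query_length, heads, head_dim)
    k = jnp.zeros((2, 4096, 8, 64))
    v = jnp.zeros((2, 4096, 8, 64))
    out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
    # out has shape (2, 4096, 8, 64); queries are processed 1024 at a time to bound the attention matrix size.
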
120
+
121
+
122
+ class FlaxAttention(nn.Module):
123
+ r"""
124
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
125
+
126
+ Parameters:
127
+ query_dim (:obj:`int`):
128
+ Input hidden states dimension
129
+ heads (:obj:`int`, *optional*, defaults to 8):
130
+ Number of heads
131
+ dim_head (:obj:`int`, *optional*, defaults to 64):
132
+ Hidden states dimension inside each head
133
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
134
+ Dropout rate
135
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
136
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
137
+ split_head_dim (`bool`, *optional*, defaults to `False`):
138
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
139
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
140
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
141
+ Parameters `dtype`
142
+
143
+ """
144
+
145
+ query_dim: int
146
+ heads: int = 8
147
+ dim_head: int = 64
148
+ dropout: float = 0.0
149
+ use_memory_efficient_attention: bool = False
150
+ split_head_dim: bool = False
151
+ dtype: jnp.dtype = jnp.float32
152
+
153
+ def setup(self):
154
+ inner_dim = self.dim_head * self.heads
155
+ self.scale = self.dim_head**-0.5
156
+
157
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
158
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
159
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
160
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
161
+
162
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
163
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
164
+
165
+ def reshape_heads_to_batch_dim(self, tensor):
166
+ batch_size, seq_len, dim = tensor.shape
167
+ head_size = self.heads
168
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
169
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
170
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
171
+ return tensor
172
+
173
+ def reshape_batch_dim_to_heads(self, tensor):
174
+ batch_size, seq_len, dim = tensor.shape
175
+ head_size = self.heads
176
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
177
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
178
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
179
+ return tensor
180
+
181
+ def __call__(self, hidden_states, context=None, deterministic=True):
182
+ context = hidden_states if context is None else context
183
+
184
+ query_proj = self.query(hidden_states)
185
+ key_proj = self.key(context)
186
+ value_proj = self.value(context)
187
+
188
+ if self.split_head_dim:
189
+ b = hidden_states.shape[0]
190
+ query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
191
+ key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
192
+ value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
193
+ else:
194
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
195
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
196
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
197
+
198
+ if self.use_memory_efficient_attention:
199
+ query_states = query_states.transpose(1, 0, 2)
200
+ key_states = key_states.transpose(1, 0, 2)
201
+ value_states = value_states.transpose(1, 0, 2)
202
+
203
+ # this if statement creates a chunk size for each layer of the unet
204
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
205
+
206
+ flatten_latent_dim = query_states.shape[-3]
207
+ if flatten_latent_dim % 64 == 0:
208
+ query_chunk_size = int(flatten_latent_dim / 64)
209
+ elif flatten_latent_dim % 16 == 0:
210
+ query_chunk_size = int(flatten_latent_dim / 16)
211
+ elif flatten_latent_dim % 4 == 0:
212
+ query_chunk_size = int(flatten_latent_dim / 4)
213
+ else:
214
+ query_chunk_size = int(flatten_latent_dim)
215
+
216
+ hidden_states = jax_memory_efficient_attention(
217
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
218
+ )
219
+ hidden_states = hidden_states.transpose(1, 0, 2)
220
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
221
+ else:
222
+ # compute attentions
223
+ if self.split_head_dim:
224
+ attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
225
+ else:
226
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
227
+
228
+ attention_scores = attention_scores * self.scale
229
+ attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
230
+
231
+ # attend to values
232
+ if self.split_head_dim:
233
+ hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
234
+ b = hidden_states.shape[0]
235
+ hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
236
+ else:
237
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
238
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
239
+
240
+ hidden_states = self.proj_attn(hidden_states)
241
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
242
+
243
+
244
+ class FlaxBasicTransformerBlock(nn.Module):
245
+ r"""
246
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
247
+ https://arxiv.org/abs/1706.03762
248
+
249
+
250
+ Parameters:
251
+ dim (:obj:`int`):
252
+ Inner hidden states dimension
253
+ n_heads (:obj:`int`):
254
+ Number of heads
255
+ d_head (:obj:`int`):
256
+ Hidden states dimension inside each head
257
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
258
+ Dropout rate
259
+ only_cross_attention (`bool`, defaults to `False`):
260
+ Whether to only apply cross attention.
261
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
262
+ Parameters `dtype`
263
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
264
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
265
+ split_head_dim (`bool`, *optional*, defaults to `False`):
266
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
267
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
268
+ """
269
+
270
+ dim: int
271
+ n_heads: int
272
+ d_head: int
273
+ dropout: float = 0.0
274
+ only_cross_attention: bool = False
275
+ dtype: jnp.dtype = jnp.float32
276
+ use_memory_efficient_attention: bool = False
277
+ split_head_dim: bool = False
278
+
279
+ def setup(self):
280
+ # self attention (or cross_attention if only_cross_attention is True)
281
+ self.attn1 = FlaxAttention(
282
+ self.dim,
283
+ self.n_heads,
284
+ self.d_head,
285
+ self.dropout,
286
+ self.use_memory_efficient_attention,
287
+ self.split_head_dim,
288
+ dtype=self.dtype,
289
+ )
290
+ # cross attention
291
+ self.attn2 = FlaxAttention(
292
+ self.dim,
293
+ self.n_heads,
294
+ self.d_head,
295
+ self.dropout,
296
+ self.use_memory_efficient_attention,
297
+ self.split_head_dim,
298
+ dtype=self.dtype,
299
+ )
300
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
301
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
302
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
303
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
304
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
305
+
306
+ def __call__(self, hidden_states, context, deterministic=True):
307
+ # self attention
308
+ residual = hidden_states
309
+ if self.only_cross_attention:
310
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
311
+ else:
312
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
313
+ hidden_states = hidden_states + residual
314
+
315
+ # cross attention
316
+ residual = hidden_states
317
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
318
+ hidden_states = hidden_states + residual
319
+
320
+ # feed forward
321
+ residual = hidden_states
322
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
323
+ hidden_states = hidden_states + residual
324
+
325
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
326
+
327
+
328
+ class FlaxTransformer2DModel(nn.Module):
329
+ r"""
330
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
331
+ https://arxiv.org/pdf/1506.02025.pdf
332
+
333
+
334
+ Parameters:
335
+ in_channels (:obj:`int`):
336
+ Input number of channels
337
+ n_heads (:obj:`int`):
338
+ Number of heads
339
+ d_head (:obj:`int`):
340
+ Hidden states dimension inside each head
341
+ depth (:obj:`int`, *optional*, defaults to 1):
342
+ Number of transformer blocks
343
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
344
+ Dropout rate
345
+ use_linear_projection (`bool`, defaults to `False`): Whether to use a Dense layer (instead of a 1x1 convolution) for the input/output projections
346
+ only_cross_attention (`bool`, defaults to `False`): Whether the inner transformer blocks attend only to `context` (no self-attention)
347
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
348
+ Parameters `dtype`
349
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
350
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
351
+ split_head_dim (`bool`, *optional*, defaults to `False`):
352
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
353
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
354
+ """
355
+
356
+ in_channels: int
357
+ n_heads: int
358
+ d_head: int
359
+ depth: int = 1
360
+ dropout: float = 0.0
361
+ use_linear_projection: bool = False
362
+ only_cross_attention: bool = False
363
+ dtype: jnp.dtype = jnp.float32
364
+ use_memory_efficient_attention: bool = False
365
+ split_head_dim: bool = False
366
+
367
+ def setup(self):
368
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
369
+
370
+ inner_dim = self.n_heads * self.d_head
371
+ if self.use_linear_projection:
372
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
373
+ else:
374
+ self.proj_in = nn.Conv(
375
+ inner_dim,
376
+ kernel_size=(1, 1),
377
+ strides=(1, 1),
378
+ padding="VALID",
379
+ dtype=self.dtype,
380
+ )
381
+
382
+ self.transformer_blocks = [
383
+ FlaxBasicTransformerBlock(
384
+ inner_dim,
385
+ self.n_heads,
386
+ self.d_head,
387
+ dropout=self.dropout,
388
+ only_cross_attention=self.only_cross_attention,
389
+ dtype=self.dtype,
390
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
391
+ split_head_dim=self.split_head_dim,
392
+ )
393
+ for _ in range(self.depth)
394
+ ]
395
+
396
+ if self.use_linear_projection:
397
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
398
+ else:
399
+ self.proj_out = nn.Conv(
400
+ inner_dim,
401
+ kernel_size=(1, 1),
402
+ strides=(1, 1),
403
+ padding="VALID",
404
+ dtype=self.dtype,
405
+ )
406
+
407
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
408
+
409
+ def __call__(self, hidden_states, context, deterministic=True):
410
+ batch, height, width, channels = hidden_states.shape
411
+ residual = hidden_states
412
+ hidden_states = self.norm(hidden_states)
413
+ if self.use_linear_projection:
414
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
415
+ hidden_states = self.proj_in(hidden_states)
416
+ else:
417
+ hidden_states = self.proj_in(hidden_states)
418
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
419
+
420
+ for transformer_block in self.transformer_blocks:
421
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
422
+
423
+ if self.use_linear_projection:
424
+ hidden_states = self.proj_out(hidden_states)
425
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
426
+ else:
427
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
428
+ hidden_states = self.proj_out(hidden_states)
429
+
430
+ hidden_states = hidden_states + residual
431
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
432
+
433
+
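To make the projection path in `FlaxTransformer2DModel.__call__` above easier to follow, here is a shape-only sketch in plain JAX (the sizes and variable names are illustrative assumptions, not values from the diff). With `use_linear_projection=True` the NHWC feature map is flattened into tokens before the `nn.Dense` projection; otherwise a 1x1 `nn.Conv` runs on the grid first and the same reshape happens afterwards.

import jax.numpy as jnp

# Illustrative NHWC feature map, as produced by the GroupNorm above.
batch, height, width, channels = 1, 8, 8, 64
hidden_states = jnp.zeros((batch, height, width, channels))

# use_linear_projection=True: flatten spatial dims into a token axis, then project with nn.Dense.
tokens = hidden_states.reshape(batch, height * width, channels)
print(tokens.shape)  # (1, 64, 64) -> fed to the transformer blocks

# use_linear_projection=False: a 1x1 nn.Conv keeps the (batch, h, w, c) layout,
# and the reshape above is applied after the convolution instead.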
434
+ class FlaxFeedForward(nn.Module):
435
+ r"""
436
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
437
+ [`FeedForward`] class, with the following simplifications:
438
+ - The activation function is currently hardcoded to a gated linear unit from:
439
+ https://arxiv.org/abs/2002.05202
440
+ - `dim_out` is equal to `dim`.
441
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
442
+
443
+ Parameters:
444
+ dim (:obj:`int`):
445
+ Inner hidden states dimension
446
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
447
+ Dropout rate
448
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
449
+ Parameters `dtype`
450
+ """
451
+
452
+ dim: int
453
+ dropout: float = 0.0
454
+ dtype: jnp.dtype = jnp.float32
455
+
456
+ def setup(self):
457
+ # The second linear layer needs to be called
458
+ # net_2 for now to match the index of the Sequential layer
459
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
460
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
461
+
462
+ def __call__(self, hidden_states, deterministic=True):
463
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
464
+ hidden_states = self.net_2(hidden_states)
465
+ return hidden_states
466
+
467
+
468
+ class FlaxGEGLU(nn.Module):
469
+ r"""
470
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
471
+ https://arxiv.org/abs/2002.05202.
472
+
473
+ Parameters:
474
+ dim (:obj:`int`):
475
+ Input hidden states dimension
476
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
477
+ Dropout rate
478
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
479
+ Parameters `dtype`
480
+ """
481
+
482
+ dim: int
483
+ dropout: float = 0.0
484
+ dtype: jnp.dtype = jnp.float32
485
+
486
+ def setup(self):
487
+ inner_dim = self.dim * 4
488
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
489
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
490
+
491
+ def __call__(self, hidden_states, deterministic=True):
492
+ hidden_states = self.proj(hidden_states)
493
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
494
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
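As a quick sanity check of the gating in `FlaxGEGLU`, the same computation can be written directly with `jax.numpy` (a minimal sketch with random weights; `dim=64` and the shapes are illustrative assumptions, not values from the diff):

import jax
import jax.numpy as jnp

dim = 64
x = jnp.ones((1, 16, dim))                                    # (batch, tokens, dim)
w = jax.random.normal(jax.random.PRNGKey(0), (dim, dim * 4 * 2)) * 0.02
h = x @ w                                                     # project to inner_dim * 2
hidden_linear, hidden_gelu = jnp.split(h, 2, axis=-1)         # same split as in __call__
out = hidden_linear * jax.nn.gelu(hidden_gelu)                # gated linear unit
print(out.shape)                                              # (1, 16, 256)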
icedit/diffusers/models/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
icedit/diffusers/models/autoencoders/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
2
+ from .autoencoder_dc import AutoencoderDC
3
+ from .autoencoder_kl import AutoencoderKL
4
+ from .autoencoder_kl_allegro import AutoencoderKLAllegro
5
+ from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
6
+ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
7
+ from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
8
+ from .autoencoder_kl_mochi import AutoencoderKLMochi
9
+ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
10
+ from .autoencoder_oobleck import AutoencoderOobleck
11
+ from .autoencoder_tiny import AutoencoderTiny
12
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
13
+ from .vq_model import VQModel
icedit/diffusers/models/autoencoders/autoencoder_asym_kl.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ...configuration_utils import ConfigMixin, register_to_config
20
+ from ...utils.accelerate_utils import apply_forward_hook
21
+ from ..modeling_outputs import AutoencoderKLOutput
22
+ from ..modeling_utils import ModelMixin
23
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
24
+
25
+
26
+ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
27
+ r"""
28
+ Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
29
+ for encoding images into latents and decoding latent representations into images.
30
+
31
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
32
+ for all models (such as downloading or saving).
33
+
34
+ Parameters:
35
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
37
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
38
+ Tuple of downsample block types.
39
+ down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
40
+ Tuple of down block output channels.
41
+ layers_per_down_block (`int`, *optional*, defaults to `1`):
42
+ Number of layers for each down block.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
44
+ Tuple of upsample block types.
45
+ up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
46
+ Tuple of up block output channels.
47
+ layers_per_up_block (`int`, *optional*, defaults to `1`):
48
+ Number of layers for each up block.
49
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
50
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
51
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
52
+ norm_num_groups (`int`, *optional*, defaults to `32`):
53
+ Number of groups to use for the first normalization layer in ResNet blocks.
54
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
55
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
56
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
57
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
58
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
59
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
60
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
61
+ """
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
69
+ down_block_out_channels: Tuple[int, ...] = (64,),
70
+ layers_per_down_block: int = 1,
71
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
72
+ up_block_out_channels: Tuple[int, ...] = (64,),
73
+ layers_per_up_block: int = 1,
74
+ act_fn: str = "silu",
75
+ latent_channels: int = 4,
76
+ norm_num_groups: int = 32,
77
+ sample_size: int = 32,
78
+ scaling_factor: float = 0.18215,
79
+ ) -> None:
80
+ super().__init__()
81
+
82
+ # pass init params to Encoder
83
+ self.encoder = Encoder(
84
+ in_channels=in_channels,
85
+ out_channels=latent_channels,
86
+ down_block_types=down_block_types,
87
+ block_out_channels=down_block_out_channels,
88
+ layers_per_block=layers_per_down_block,
89
+ act_fn=act_fn,
90
+ norm_num_groups=norm_num_groups,
91
+ double_z=True,
92
+ )
93
+
94
+ # pass init params to Decoder
95
+ self.decoder = MaskConditionDecoder(
96
+ in_channels=latent_channels,
97
+ out_channels=out_channels,
98
+ up_block_types=up_block_types,
99
+ block_out_channels=up_block_out_channels,
100
+ layers_per_block=layers_per_up_block,
101
+ act_fn=act_fn,
102
+ norm_num_groups=norm_num_groups,
103
+ )
104
+
105
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
106
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
107
+
108
+ self.use_slicing = False
109
+ self.use_tiling = False
110
+
111
+ self.register_to_config(block_out_channels=up_block_out_channels)
112
+ self.register_to_config(force_upcast=False)
113
+
114
+ @apply_forward_hook
115
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
116
+ h = self.encoder(x)
117
+ moments = self.quant_conv(h)
118
+ posterior = DiagonalGaussianDistribution(moments)
119
+
120
+ if not return_dict:
121
+ return (posterior,)
122
+
123
+ return AutoencoderKLOutput(latent_dist=posterior)
124
+
125
+ def _decode(
126
+ self,
127
+ z: torch.Tensor,
128
+ image: Optional[torch.Tensor] = None,
129
+ mask: Optional[torch.Tensor] = None,
130
+ return_dict: bool = True,
131
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
132
+ z = self.post_quant_conv(z)
133
+ dec = self.decoder(z, image, mask)
134
+
135
+ if not return_dict:
136
+ return (dec,)
137
+
138
+ return DecoderOutput(sample=dec)
139
+
140
+ @apply_forward_hook
141
+ def decode(
142
+ self,
143
+ z: torch.Tensor,
144
+ generator: Optional[torch.Generator] = None,
145
+ image: Optional[torch.Tensor] = None,
146
+ mask: Optional[torch.Tensor] = None,
147
+ return_dict: bool = True,
148
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
149
+ decoded = self._decode(z, image, mask).sample
150
+
151
+ if not return_dict:
152
+ return (decoded,)
153
+
154
+ return DecoderOutput(sample=decoded)
155
+
156
+ def forward(
157
+ self,
158
+ sample: torch.Tensor,
159
+ mask: Optional[torch.Tensor] = None,
160
+ sample_posterior: bool = False,
161
+ return_dict: bool = True,
162
+ generator: Optional[torch.Generator] = None,
163
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
164
+ r"""
165
+ Args:
166
+ sample (`torch.Tensor`): Input sample.
167
+ mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
168
+ sample_posterior (`bool`, *optional*, defaults to `False`):
169
+ Whether to sample from the posterior.
170
+ return_dict (`bool`, *optional*, defaults to `True`):
171
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
172
+ """
173
+ x = sample
174
+ posterior = self.encode(x).latent_dist
175
+ if sample_posterior:
176
+ z = posterior.sample(generator=generator)
177
+ else:
178
+ z = posterior.mode()
179
+ dec = self.decode(z, generator, sample, mask).sample
180
+
181
+ if not return_dict:
182
+ return (dec,)
183
+
184
+ return DecoderOutput(sample=dec)
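A minimal round-trip sketch for the class above, assuming the vendored `icedit/diffusers` package is importable as `diffusers` (as `app.py` arranges), the tiny default configuration, and random weights; the tensor sizes are illustrative:

import torch
from diffusers import AsymmetricAutoencoderKL  # assumes the vendored package is on sys.path

vae = AsymmetricAutoencoderKL().eval()               # default config: one block, 64 channels
image = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    posterior = vae.encode(image).latent_dist        # DiagonalGaussianDistribution
    z = posterior.mode()                             # deterministic latent, (1, 4, 64, 64)
    z = z * vae.config.scaling_factor                # scale toward unit variance (see docstring)
    recon = vae.decode(z / vae.config.scaling_factor, image=image).sample

print(recon.shape)                                   # torch.Size([1, 3, 64, 64])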
icedit/diffusers/models/autoencoders/autoencoder_dc.py ADDED
@@ -0,0 +1,620 @@
1
+ # Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import FromOriginalModelMixin
24
+ from ...utils.accelerate_utils import apply_forward_hook
25
+ from ..activations import get_activation
26
+ from ..attention_processor import SanaMultiscaleLinearAttention
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import RMSNorm, get_normalization
29
+ from ..transformers.sana_transformer import GLUMBConv
30
+ from .vae import DecoderOutput, EncoderOutput
31
+
32
+
33
+ class ResBlock(nn.Module):
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ norm_type: str = "batch_norm",
39
+ act_fn: str = "relu6",
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ self.norm_type = norm_type
44
+
45
+ self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
46
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
47
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
48
+ self.norm = get_normalization(norm_type, out_channels)
49
+
50
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
51
+ residual = hidden_states
52
+ hidden_states = self.conv1(hidden_states)
53
+ hidden_states = self.nonlinearity(hidden_states)
54
+ hidden_states = self.conv2(hidden_states)
55
+
56
+ if self.norm_type == "rms_norm":
57
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
58
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
59
+ else:
60
+ hidden_states = self.norm(hidden_states)
61
+
62
+ return hidden_states + residual
63
+
64
+
65
+ class EfficientViTBlock(nn.Module):
66
+ def __init__(
67
+ self,
68
+ in_channels: int,
69
+ mult: float = 1.0,
70
+ attention_head_dim: int = 32,
71
+ qkv_multiscales: Tuple[int, ...] = (5,),
72
+ norm_type: str = "batch_norm",
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ self.attn = SanaMultiscaleLinearAttention(
77
+ in_channels=in_channels,
78
+ out_channels=in_channels,
79
+ mult=mult,
80
+ attention_head_dim=attention_head_dim,
81
+ norm_type=norm_type,
82
+ kernel_sizes=qkv_multiscales,
83
+ residual_connection=True,
84
+ )
85
+
86
+ self.conv_out = GLUMBConv(
87
+ in_channels=in_channels,
88
+ out_channels=in_channels,
89
+ norm_type="rms_norm",
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = self.attn(x)
94
+ x = self.conv_out(x)
95
+ return x
96
+
97
+
98
+ def get_block(
99
+ block_type: str,
100
+ in_channels: int,
101
+ out_channels: int,
102
+ attention_head_dim: int,
103
+ norm_type: str,
104
+ act_fn: str,
105
+ qkv_mutliscales: Tuple[int] = (),
106
+ ):
107
+ if block_type == "ResBlock":
108
+ block = ResBlock(in_channels, out_channels, norm_type, act_fn)
109
+
110
+ elif block_type == "EfficientViTBlock":
111
+ block = EfficientViTBlock(
112
+ in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
113
+ )
114
+
115
+ else:
116
+ raise ValueError(f"Block with {block_type=} is not supported.")
117
+
118
+ return block
119
+
120
+
121
+ class DCDownBlock2d(nn.Module):
122
+ def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
123
+ super().__init__()
124
+
125
+ self.downsample = downsample
126
+ self.factor = 2
127
+ self.stride = 1 if downsample else 2
128
+ self.group_size = in_channels * self.factor**2 // out_channels
129
+ self.shortcut = shortcut
130
+
131
+ out_ratio = self.factor**2
132
+ if downsample:
133
+ assert out_channels % out_ratio == 0
134
+ out_channels = out_channels // out_ratio
135
+
136
+ self.conv = nn.Conv2d(
137
+ in_channels,
138
+ out_channels,
139
+ kernel_size=3,
140
+ stride=self.stride,
141
+ padding=1,
142
+ )
143
+
144
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
145
+ x = self.conv(hidden_states)
146
+ if self.downsample:
147
+ x = F.pixel_unshuffle(x, self.factor)
148
+
149
+ if self.shortcut:
150
+ y = F.pixel_unshuffle(hidden_states, self.factor)
151
+ y = y.unflatten(1, (-1, self.group_size))
152
+ y = y.mean(dim=2)
153
+ hidden_states = x + y
154
+ else:
155
+ hidden_states = x
156
+
157
+ return hidden_states
158
+
159
+
160
+ class DCUpBlock2d(nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels: int,
164
+ out_channels: int,
165
+ interpolate: bool = False,
166
+ shortcut: bool = True,
167
+ interpolation_mode: str = "nearest",
168
+ ) -> None:
169
+ super().__init__()
170
+
171
+ self.interpolate = interpolate
172
+ self.interpolation_mode = interpolation_mode
173
+ self.shortcut = shortcut
174
+ self.factor = 2
175
+ self.repeats = out_channels * self.factor**2 // in_channels
176
+
177
+ out_ratio = self.factor**2
178
+
179
+ if not interpolate:
180
+ out_channels = out_channels * out_ratio
181
+
182
+ self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
183
+
184
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
185
+ if self.interpolate:
186
+ x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
187
+ x = self.conv(x)
188
+ else:
189
+ x = self.conv(hidden_states)
190
+ x = F.pixel_shuffle(x, self.factor)
191
+
192
+ if self.shortcut:
193
+ y = hidden_states.repeat_interleave(self.repeats, dim=1)
194
+ y = F.pixel_shuffle(y, self.factor)
195
+ hidden_states = x + y
196
+ else:
197
+ hidden_states = x
198
+
199
+ return hidden_states
200
+
201
+
202
+ class Encoder(nn.Module):
203
+ def __init__(
204
+ self,
205
+ in_channels: int,
206
+ latent_channels: int,
207
+ attention_head_dim: int = 32,
208
+ block_type: Union[str, Tuple[str]] = "ResBlock",
209
+ block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
210
+ layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
211
+ qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
212
+ downsample_block_type: str = "pixel_unshuffle",
213
+ out_shortcut: bool = True,
214
+ ):
215
+ super().__init__()
216
+
217
+ num_blocks = len(block_out_channels)
218
+
219
+ if isinstance(block_type, str):
220
+ block_type = (block_type,) * num_blocks
221
+
222
+ if layers_per_block[0] > 0:
223
+ self.conv_in = nn.Conv2d(
224
+ in_channels,
225
+ block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1,
229
+ )
230
+ else:
231
+ self.conv_in = DCDownBlock2d(
232
+ in_channels=in_channels,
233
+ out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
234
+ downsample=downsample_block_type == "pixel_unshuffle",
235
+ shortcut=False,
236
+ )
237
+
238
+ down_blocks = []
239
+ for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
240
+ down_block_list = []
241
+
242
+ for _ in range(num_layers):
243
+ block = get_block(
244
+ block_type[i],
245
+ out_channel,
246
+ out_channel,
247
+ attention_head_dim=attention_head_dim,
248
+ norm_type="rms_norm",
249
+ act_fn="silu",
250
+ qkv_mutliscales=qkv_multiscales[i],
251
+ )
252
+ down_block_list.append(block)
253
+
254
+ if i < num_blocks - 1 and num_layers > 0:
255
+ downsample_block = DCDownBlock2d(
256
+ in_channels=out_channel,
257
+ out_channels=block_out_channels[i + 1],
258
+ downsample=downsample_block_type == "pixel_unshuffle",
259
+ shortcut=True,
260
+ )
261
+ down_block_list.append(downsample_block)
262
+
263
+ down_blocks.append(nn.Sequential(*down_block_list))
264
+
265
+ self.down_blocks = nn.ModuleList(down_blocks)
266
+
267
+ self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
268
+
269
+ self.out_shortcut = out_shortcut
270
+ if out_shortcut:
271
+ self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
272
+
273
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
274
+ hidden_states = self.conv_in(hidden_states)
275
+ for down_block in self.down_blocks:
276
+ hidden_states = down_block(hidden_states)
277
+
278
+ if self.out_shortcut:
279
+ x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
280
+ x = x.mean(dim=2)
281
+ hidden_states = self.conv_out(hidden_states) + x
282
+ else:
283
+ hidden_states = self.conv_out(hidden_states)
284
+
285
+ return hidden_states
286
+
287
+
288
+ class Decoder(nn.Module):
289
+ def __init__(
290
+ self,
291
+ in_channels: int,
292
+ latent_channels: int,
293
+ attention_head_dim: int = 32,
294
+ block_type: Union[str, Tuple[str]] = "ResBlock",
295
+ block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
296
+ layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
297
+ qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
298
+ norm_type: Union[str, Tuple[str]] = "rms_norm",
299
+ act_fn: Union[str, Tuple[str]] = "silu",
300
+ upsample_block_type: str = "pixel_shuffle",
301
+ in_shortcut: bool = True,
302
+ ):
303
+ super().__init__()
304
+
305
+ num_blocks = len(block_out_channels)
306
+
307
+ if isinstance(block_type, str):
308
+ block_type = (block_type,) * num_blocks
309
+ if isinstance(norm_type, str):
310
+ norm_type = (norm_type,) * num_blocks
311
+ if isinstance(act_fn, str):
312
+ act_fn = (act_fn,) * num_blocks
313
+
314
+ self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
315
+
316
+ self.in_shortcut = in_shortcut
317
+ if in_shortcut:
318
+ self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
319
+
320
+ up_blocks = []
321
+ for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
322
+ up_block_list = []
323
+
324
+ if i < num_blocks - 1 and num_layers > 0:
325
+ upsample_block = DCUpBlock2d(
326
+ block_out_channels[i + 1],
327
+ out_channel,
328
+ interpolate=upsample_block_type == "interpolate",
329
+ shortcut=True,
330
+ )
331
+ up_block_list.append(upsample_block)
332
+
333
+ for _ in range(num_layers):
334
+ block = get_block(
335
+ block_type[i],
336
+ out_channel,
337
+ out_channel,
338
+ attention_head_dim=attention_head_dim,
339
+ norm_type=norm_type[i],
340
+ act_fn=act_fn[i],
341
+ qkv_mutliscales=qkv_multiscales[i],
342
+ )
343
+ up_block_list.append(block)
344
+
345
+ up_blocks.insert(0, nn.Sequential(*up_block_list))
346
+
347
+ self.up_blocks = nn.ModuleList(up_blocks)
348
+
349
+ channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
350
+
351
+ self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
352
+ self.conv_act = nn.ReLU()
353
+ self.conv_out = None
354
+
355
+ if layers_per_block[0] > 0:
356
+ self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
357
+ else:
358
+ self.conv_out = DCUpBlock2d(
359
+ channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
360
+ )
361
+
362
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
363
+ if self.in_shortcut:
364
+ x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
365
+ hidden_states = self.conv_in(hidden_states) + x
366
+ else:
367
+ hidden_states = self.conv_in(hidden_states)
368
+
369
+ for up_block in reversed(self.up_blocks):
370
+ hidden_states = up_block(hidden_states)
371
+
372
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
373
+ hidden_states = self.conv_act(hidden_states)
374
+ hidden_states = self.conv_out(hidden_states)
375
+ return hidden_states
376
+
377
+
378
+ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
379
+ r"""
380
+ An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
381
+ [SANA](https://arxiv.org/abs/2410.10629).
382
+
383
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
384
+ for all models (such as downloading or saving).
385
+
386
+ Args:
387
+ in_channels (`int`, defaults to `3`):
388
+ The number of input channels in samples.
389
+ latent_channels (`int`, defaults to `32`):
390
+ The number of channels in the latent space representation.
391
+ encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
392
+ The type(s) of block to use in the encoder.
393
+ decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
394
+ The type(s) of block to use in the decoder.
395
+ encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
396
+ The number of output channels for each block in the encoder.
397
+ decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
398
+ The number of output channels for each block in the decoder.
399
+ encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
400
+ The number of layers per block in the encoder.
401
+ decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
402
+ The number of layers per block in the decoder.
403
+ encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
404
+ Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
405
+ decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
406
+ Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
407
+ upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
408
+ The type of block to use for upsampling in the decoder.
409
+ downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
410
+ The type of block to use for downsampling in the encoder.
411
+ decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
412
+ The normalization type(s) to use in the decoder.
413
+ decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
414
+ The activation function(s) to use in the decoder.
415
+ scaling_factor (`float`, defaults to `1.0`):
416
+ The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
417
+ space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
418
+ z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back
419
+ to the original scale with the formula: `z = 1 / scaling_factor * z`.
420
+ """
421
+
422
+ _supports_gradient_checkpointing = False
423
+
424
+ @register_to_config
425
+ def __init__(
426
+ self,
427
+ in_channels: int = 3,
428
+ latent_channels: int = 32,
429
+ attention_head_dim: int = 32,
430
+ encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
431
+ decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
432
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
433
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
434
+ encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
435
+ decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
436
+ encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
437
+ decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
438
+ upsample_block_type: str = "pixel_shuffle",
439
+ downsample_block_type: str = "pixel_unshuffle",
440
+ decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
441
+ decoder_act_fns: Union[str, Tuple[str]] = "silu",
442
+ scaling_factor: float = 1.0,
443
+ ) -> None:
444
+ super().__init__()
445
+
446
+ self.encoder = Encoder(
447
+ in_channels=in_channels,
448
+ latent_channels=latent_channels,
449
+ attention_head_dim=attention_head_dim,
450
+ block_type=encoder_block_types,
451
+ block_out_channels=encoder_block_out_channels,
452
+ layers_per_block=encoder_layers_per_block,
453
+ qkv_multiscales=encoder_qkv_multiscales,
454
+ downsample_block_type=downsample_block_type,
455
+ )
456
+ self.decoder = Decoder(
457
+ in_channels=in_channels,
458
+ latent_channels=latent_channels,
459
+ attention_head_dim=attention_head_dim,
460
+ block_type=decoder_block_types,
461
+ block_out_channels=decoder_block_out_channels,
462
+ layers_per_block=decoder_layers_per_block,
463
+ qkv_multiscales=decoder_qkv_multiscales,
464
+ norm_type=decoder_norm_types,
465
+ act_fn=decoder_act_fns,
466
+ upsample_block_type=upsample_block_type,
467
+ )
468
+
469
+ self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
470
+ self.temporal_compression_ratio = 1
471
+
472
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
473
+ # to perform decoding of a single video latent at a time.
474
+ self.use_slicing = False
475
+
476
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
477
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
478
+ # intermediate tiles together, the memory requirement can be lowered.
479
+ self.use_tiling = False
480
+
481
+ # The minimal tile height and width for spatial tiling to be used
482
+ self.tile_sample_min_height = 512
483
+ self.tile_sample_min_width = 512
484
+
485
+ # The minimal distance between two spatial tiles
486
+ self.tile_sample_stride_height = 448
487
+ self.tile_sample_stride_width = 448
488
+
489
+ def enable_tiling(
490
+ self,
491
+ tile_sample_min_height: Optional[int] = None,
492
+ tile_sample_min_width: Optional[int] = None,
493
+ tile_sample_stride_height: Optional[float] = None,
494
+ tile_sample_stride_width: Optional[float] = None,
495
+ ) -> None:
496
+ r"""
497
+ Enable tiled AE decoding. When this option is enabled, the AE will split the input tensor into tiles to compute
498
+ decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
499
+ processing larger images.
500
+
501
+ Args:
502
+ tile_sample_min_height (`int`, *optional*):
503
+ The minimum height required for a sample to be separated into tiles across the height dimension.
504
+ tile_sample_min_width (`int`, *optional*):
505
+ The minimum width required for a sample to be separated into tiles across the width dimension.
506
+ tile_sample_stride_height (`int`, *optional*):
507
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
508
+ no tiling artifacts produced across the height dimension.
509
+ tile_sample_stride_width (`int`, *optional*):
510
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
511
+ artifacts produced across the width dimension.
512
+ """
513
+ self.use_tiling = True
514
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
515
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
516
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
517
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
518
+
519
+ def disable_tiling(self) -> None:
520
+ r"""
521
+ Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
522
+ decoding in one step.
523
+ """
524
+ self.use_tiling = False
525
+
526
+ def enable_slicing(self) -> None:
527
+ r"""
528
+ Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
529
+ decoding in several steps. This is useful to save some memory and allow larger batch sizes.
530
+ """
531
+ self.use_slicing = True
532
+
533
+ def disable_slicing(self) -> None:
534
+ r"""
535
+ Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
536
+ decoding in one step.
537
+ """
538
+ self.use_slicing = False
539
+
540
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
541
+ batch_size, num_channels, height, width = x.shape
542
+
543
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
544
+ return self.tiled_encode(x, return_dict=False)[0]
545
+
546
+ encoded = self.encoder(x)
547
+
548
+ return encoded
549
+
550
+ @apply_forward_hook
551
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
552
+ r"""
553
+ Encode a batch of images into latents.
554
+
555
+ Args:
556
+ x (`torch.Tensor`): Input batch of images.
557
+ return_dict (`bool`, defaults to `True`):
558
+ Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple.
559
+
560
+ Returns:
561
+ The latent representations of the encoded videos. If `return_dict` is True, a
562
+ [`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned.
563
+ """
564
+ if self.use_slicing and x.shape[0] > 1:
565
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
566
+ encoded = torch.cat(encoded_slices)
567
+ else:
568
+ encoded = self._encode(x)
569
+
570
+ if not return_dict:
571
+ return (encoded,)
572
+ return EncoderOutput(latent=encoded)
573
+
574
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
575
+ batch_size, num_channels, height, width = z.shape
576
+
577
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
578
+ return self.tiled_decode(z, return_dict=False)[0]
579
+
580
+ decoded = self.decoder(z)
581
+
582
+ return decoded
583
+
584
+ @apply_forward_hook
585
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
586
+ r"""
587
+ Decode a batch of images.
588
+
589
+ Args:
590
+ z (`torch.Tensor`): Input batch of latent vectors.
591
+ return_dict (`bool`, defaults to `True`):
592
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
593
+
594
+ Returns:
595
+ [`~models.vae.DecoderOutput`] or `tuple`:
596
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
597
+ returned.
598
+ """
599
+ if self.use_slicing and z.size(0) > 1:
600
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
601
+ decoded = torch.cat(decoded_slices)
602
+ else:
603
+ decoded = self._decode(z)
604
+
605
+ if not return_dict:
606
+ return (decoded,)
607
+ return DecoderOutput(sample=decoded)
608
+
609
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
610
+ raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
611
+
612
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
613
+ raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
614
+
615
+ def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
616
+ encoded = self.encode(sample, return_dict=False)[0]
617
+ decoded = self.decode(encoded, return_dict=False)[0]
618
+ if not return_dict:
619
+ return (decoded,)
620
+ return DecoderOutput(sample=decoded)
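Shape bookkeeping for the default `AutoencoderDC` configuration above (the 1024-pixel input is an assumed example): the spatial compression ratio follows directly from the number of encoder blocks.

encoder_block_out_channels = (128, 256, 512, 512, 1024, 1024)
latent_channels = 32
spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)   # 2**5 = 32

height = width = 1024                                                    # assumed input size
latent_shape = (1, latent_channels, height // spatial_compression_ratio, width // spatial_compression_ratio)
print(spatial_compression_ratio, latent_shape)                           # 32 (1, 32, 32, 32)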
icedit/diffusers/models/autoencoders/autoencoder_kl.py ADDED
@@ -0,0 +1,571 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ...configuration_utils import ConfigMixin, register_to_config
20
+ from ...loaders import PeftAdapterMixin
21
+ from ...loaders.single_file_model import FromOriginalModelMixin
22
+ from ...utils import deprecate
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ Attention,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ FusedAttnProcessor2_0,
32
+ )
33
+ from ..modeling_outputs import AutoencoderKLOutput
34
+ from ..modeling_utils import ModelMixin
35
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
36
+
37
+
38
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
39
+ r"""
40
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
41
+
42
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
43
+ for all models (such as downloading or saving).
44
+
45
+ Parameters:
46
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
47
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
48
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
49
+ Tuple of downsample block types.
50
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
51
+ Tuple of upsample block types.
52
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
53
+ Tuple of block output channels.
54
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
55
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
56
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
57
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
58
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
59
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
60
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
61
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
62
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
63
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
64
+ force_upcast (`bool`, *optional*, default to `True`):
65
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
66
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
67
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
68
+ mid_block_add_attention (`bool`, *optional*, default to `True`):
69
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
70
+ mid_block will only have resnet blocks
71
+ """
72
+
73
+ _supports_gradient_checkpointing = True
74
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ in_channels: int = 3,
80
+ out_channels: int = 3,
81
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
82
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
83
+ block_out_channels: Tuple[int] = (64,),
84
+ layers_per_block: int = 1,
85
+ act_fn: str = "silu",
86
+ latent_channels: int = 4,
87
+ norm_num_groups: int = 32,
88
+ sample_size: int = 32,
89
+ scaling_factor: float = 0.18215,
90
+ shift_factor: Optional[float] = None,
91
+ latents_mean: Optional[Tuple[float]] = None,
92
+ latents_std: Optional[Tuple[float]] = None,
93
+ force_upcast: float = True,
94
+ use_quant_conv: bool = True,
95
+ use_post_quant_conv: bool = True,
96
+ mid_block_add_attention: bool = True,
97
+ ):
98
+ super().__init__()
99
+
100
+ # pass init params to Encoder
101
+ self.encoder = Encoder(
102
+ in_channels=in_channels,
103
+ out_channels=latent_channels,
104
+ down_block_types=down_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ act_fn=act_fn,
108
+ norm_num_groups=norm_num_groups,
109
+ double_z=True,
110
+ mid_block_add_attention=mid_block_add_attention,
111
+ )
112
+
113
+ # pass init params to Decoder
114
+ self.decoder = Decoder(
115
+ in_channels=latent_channels,
116
+ out_channels=out_channels,
117
+ up_block_types=up_block_types,
118
+ block_out_channels=block_out_channels,
119
+ layers_per_block=layers_per_block,
120
+ norm_num_groups=norm_num_groups,
121
+ act_fn=act_fn,
122
+ mid_block_add_attention=mid_block_add_attention,
123
+ )
124
+
125
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
126
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
127
+
128
+ self.use_slicing = False
129
+ self.use_tiling = False
130
+
131
+ # only relevant if vae tiling is enabled
132
+ self.tile_sample_min_size = self.config.sample_size
133
+ sample_size = (
134
+ self.config.sample_size[0]
135
+ if isinstance(self.config.sample_size, (list, tuple))
136
+ else self.config.sample_size
137
+ )
138
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
139
+ self.tile_overlap_factor = 0.25
140
+
141
+ def _set_gradient_checkpointing(self, module, value=False):
142
+ if isinstance(module, (Encoder, Decoder)):
143
+ module.gradient_checkpointing = value
144
+
145
+ def enable_tiling(self, use_tiling: bool = True):
146
+ r"""
147
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
148
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
149
+ processing larger images.
150
+ """
151
+ self.use_tiling = use_tiling
152
+
153
+ def disable_tiling(self):
154
+ r"""
155
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
156
+ decoding in one step.
157
+ """
158
+ self.enable_tiling(False)
159
+
160
+ def enable_slicing(self):
161
+ r"""
162
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
163
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
164
+ """
165
+ self.use_slicing = True
166
+
167
+ def disable_slicing(self):
168
+ r"""
169
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
170
+ decoding in one step.
171
+ """
172
+ self.use_slicing = False
173
+
174
+ @property
175
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
176
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
177
+ r"""
178
+ Returns:
179
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
180
+ indexed by its weight name.
181
+ """
182
+ # set recursively
183
+ processors = {}
184
+
185
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
186
+ if hasattr(module, "get_processor"):
187
+ processors[f"{name}.processor"] = module.get_processor()
188
+
189
+ for sub_name, child in module.named_children():
190
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
191
+
192
+ return processors
193
+
194
+ for name, module in self.named_children():
195
+ fn_recursive_add_processors(name, module, processors)
196
+
197
+ return processors
198
+
199
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
200
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
201
+ r"""
202
+ Sets the attention processor to use to compute attention.
203
+
204
+ Parameters:
205
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
206
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
207
+ for **all** `Attention` layers.
208
+
209
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
210
+ processor. This is strongly recommended when setting trainable attention processors.
211
+
212
+ """
213
+ count = len(self.attn_processors.keys())
214
+
215
+ if isinstance(processor, dict) and len(processor) != count:
216
+ raise ValueError(
217
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
218
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
219
+ )
220
+
221
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
222
+ if hasattr(module, "set_processor"):
223
+ if not isinstance(processor, dict):
224
+ module.set_processor(processor)
225
+ else:
226
+ module.set_processor(processor.pop(f"{name}.processor"))
227
+
228
+ for sub_name, child in module.named_children():
229
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
230
+
231
+ for name, module in self.named_children():
232
+ fn_recursive_attn_processor(name, module, processor)
233
+
234
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
235
+ def set_default_attn_processor(self):
236
+ """
237
+ Disables custom attention processors and sets the default attention implementation.
238
+ """
239
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
240
+ processor = AttnAddedKVProcessor()
241
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
242
+ processor = AttnProcessor()
243
+ else:
244
+ raise ValueError(
245
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
246
+ )
247
+
248
+ self.set_attn_processor(processor)
249
+
250
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
251
+ batch_size, num_channels, height, width = x.shape
252
+
253
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
254
+ return self._tiled_encode(x)
255
+
256
+ enc = self.encoder(x)
257
+ if self.quant_conv is not None:
258
+ enc = self.quant_conv(enc)
259
+
260
+ return enc
261
+
262
+ @apply_forward_hook
263
+ def encode(
264
+ self, x: torch.Tensor, return_dict: bool = True
265
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
266
+ """
267
+ Encode a batch of images into latents.
268
+
269
+ Args:
270
+ x (`torch.Tensor`): Input batch of images.
271
+ return_dict (`bool`, *optional*, defaults to `True`):
272
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
273
+
274
+ Returns:
275
+ The latent representations of the encoded images. If `return_dict` is True, a
276
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
277
+ """
278
+ if self.use_slicing and x.shape[0] > 1:
279
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
280
+ h = torch.cat(encoded_slices)
281
+ else:
282
+ h = self._encode(x)
283
+
284
+ posterior = DiagonalGaussianDistribution(h)
285
+
286
+ if not return_dict:
287
+ return (posterior,)
288
+
289
+ return AutoencoderKLOutput(latent_dist=posterior)
290
+
291
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
292
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
293
+ return self.tiled_decode(z, return_dict=return_dict)
294
+
295
+ if self.post_quant_conv is not None:
296
+ z = self.post_quant_conv(z)
297
+
298
+ dec = self.decoder(z)
299
+
300
+ if not return_dict:
301
+ return (dec,)
302
+
303
+ return DecoderOutput(sample=dec)
304
+
305
+ @apply_forward_hook
306
+ def decode(
307
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
308
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
309
+ """
310
+ Decode a batch of images.
311
+
312
+ Args:
313
+ z (`torch.Tensor`): Input batch of latent vectors.
314
+ return_dict (`bool`, *optional*, defaults to `True`):
315
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
316
+
317
+ Returns:
318
+ [`~models.vae.DecoderOutput`] or `tuple`:
319
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
320
+ returned.
321
+
322
+ """
323
+ if self.use_slicing and z.shape[0] > 1:
324
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
325
+ decoded = torch.cat(decoded_slices)
326
+ else:
327
+ decoded = self._decode(z).sample
328
+
329
+ if not return_dict:
330
+ return (decoded,)
331
+
332
+ return DecoderOutput(sample=decoded)
333
+
334
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
335
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
336
+ for y in range(blend_extent):
337
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
338
+ return b
339
+
340
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
341
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
342
+ for x in range(blend_extent):
343
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
344
+ return b
345
+
346
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
347
+ r"""Encode a batch of images using a tiled encoder.
348
+
349
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
350
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
351
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
352
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
353
+ output, but they should be much less noticeable.
354
+
355
+ Args:
356
+ x (`torch.Tensor`): Input batch of images.
357
+
358
+ Returns:
359
+ `torch.Tensor`:
360
+ The latent representation of the encoded videos.
361
+ """
362
+
363
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
364
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
365
+ row_limit = self.tile_latent_min_size - blend_extent
366
+
367
+ # Split the image into 512x512 tiles and encode them separately.
368
+ rows = []
369
+ for i in range(0, x.shape[2], overlap_size):
370
+ row = []
371
+ for j in range(0, x.shape[3], overlap_size):
372
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
373
+ tile = self.encoder(tile)
374
+ if self.config.use_quant_conv:
375
+ tile = self.quant_conv(tile)
376
+ row.append(tile)
377
+ rows.append(row)
378
+ result_rows = []
379
+ for i, row in enumerate(rows):
380
+ result_row = []
381
+ for j, tile in enumerate(row):
382
+ # blend the above tile and the left tile
383
+ # to the current tile and add the current tile to the result row
384
+ if i > 0:
385
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
386
+ if j > 0:
387
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
388
+ result_row.append(tile[:, :, :row_limit, :row_limit])
389
+ result_rows.append(torch.cat(result_row, dim=3))
390
+
391
+ enc = torch.cat(result_rows, dim=2)
392
+ return enc
393
+
394
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
395
+ r"""Encode a batch of images using a tiled encoder.
396
+
397
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
398
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
399
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
400
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
401
+ output, but they should be much less noticeable.
402
+
403
+ Args:
404
+ x (`torch.Tensor`): Input batch of images.
405
+ return_dict (`bool`, *optional*, defaults to `True`):
406
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
407
+
408
+ Returns:
409
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
410
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
411
+ `tuple` is returned.
412
+ """
413
+ deprecation_message = (
414
+ "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
415
+ "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
416
+ "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
417
+ )
418
+ deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
419
+
420
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
421
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
422
+ row_limit = self.tile_latent_min_size - blend_extent
423
+
424
+ # Split the image into 512x512 tiles and encode them separately.
425
+ rows = []
426
+ for i in range(0, x.shape[2], overlap_size):
427
+ row = []
428
+ for j in range(0, x.shape[3], overlap_size):
429
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
430
+ tile = self.encoder(tile)
431
+ if self.config.use_quant_conv:
432
+ tile = self.quant_conv(tile)
433
+ row.append(tile)
434
+ rows.append(row)
435
+ result_rows = []
436
+ for i, row in enumerate(rows):
437
+ result_row = []
438
+ for j, tile in enumerate(row):
439
+ # blend the tile above and the tile to the left into the current tile,
440
+ # then append the (cropped) current tile to the result row
441
+ if i > 0:
442
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
443
+ if j > 0:
444
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
445
+ result_row.append(tile[:, :, :row_limit, :row_limit])
446
+ result_rows.append(torch.cat(result_row, dim=3))
447
+
448
+ moments = torch.cat(result_rows, dim=2)
449
+ posterior = DiagonalGaussianDistribution(moments)
450
+
451
+ if not return_dict:
452
+ return (posterior,)
453
+
454
+ return AutoencoderKLOutput(latent_dist=posterior)
455
+
456
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
457
+ r"""
458
+ Decode a batch of images using a tiled decoder.
459
+
460
+ Args:
461
+ z (`torch.Tensor`): Input batch of latent vectors.
462
+ return_dict (`bool`, *optional*, defaults to `True`):
463
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
464
+
465
+ Returns:
466
+ [`~models.vae.DecoderOutput`] or `tuple`:
467
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
468
+ returned.
469
+ """
470
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
472
+ row_limit = self.tile_sample_min_size - blend_extent
473
+
474
+ # Split z into overlapping tiles of size `tile_latent_min_size` (64x64 in the common configuration) and decode them separately.
475
+ # The tiles have an overlap to avoid seams between tiles.
476
+ rows = []
477
+ for i in range(0, z.shape[2], overlap_size):
478
+ row = []
479
+ for j in range(0, z.shape[3], overlap_size):
480
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
481
+ if self.config.use_post_quant_conv:
482
+ tile = self.post_quant_conv(tile)
483
+ decoded = self.decoder(tile)
484
+ row.append(decoded)
485
+ rows.append(row)
486
+ result_rows = []
487
+ for i, row in enumerate(rows):
488
+ result_row = []
489
+ for j, tile in enumerate(row):
490
+ # blend the tile above and the tile to the left into the current tile,
491
+ # then append the (cropped) current tile to the result row
492
+ if i > 0:
493
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
494
+ if j > 0:
495
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
496
+ result_row.append(tile[:, :, :row_limit, :row_limit])
497
+ result_rows.append(torch.cat(result_row, dim=3))
498
+
499
+ dec = torch.cat(result_rows, dim=2)
500
+ if not return_dict:
501
+ return (dec,)
502
+
503
+ return DecoderOutput(sample=dec)
504
+
505
+ def forward(
506
+ self,
507
+ sample: torch.Tensor,
508
+ sample_posterior: bool = False,
509
+ return_dict: bool = True,
510
+ generator: Optional[torch.Generator] = None,
511
+ ) -> Union[DecoderOutput, torch.Tensor]:
512
+ r"""
513
+ Args:
514
+ sample (`torch.Tensor`): Input sample.
515
+ sample_posterior (`bool`, *optional*, defaults to `False`):
516
+ Whether to sample from the posterior.
517
+ return_dict (`bool`, *optional*, defaults to `True`):
518
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
519
+ generator (`torch.Generator`, *optional*):
+ PyTorch random number generator.
+ """
520
+ x = sample
521
+ posterior = self.encode(x).latent_dist
522
+ if sample_posterior:
523
+ z = posterior.sample(generator=generator)
524
+ else:
525
+ z = posterior.mode()
526
+ dec = self.decode(z).sample
527
+
528
+ if not return_dict:
529
+ return (dec,)
530
+
531
+ return DecoderOutput(sample=dec)
532
+
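+ # Minimal usage sketch (illustrative; assumes a loaded `vae: AutoencoderKL` and an image batch
+ # `x` of shape (B, 3, H, W) scaled to [-1, 1]):
+ #   posterior = vae.encode(x).latent_dist
+ #   z = posterior.sample(generator=torch.Generator().manual_seed(0))
+ #   recon = vae.decode(z).sample
+ # which matches what `forward` above does with `sample_posterior=True`.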
533
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
534
+ def fuse_qkv_projections(self):
535
+ """
536
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
537
+ are fused. For cross-attention modules, key and value projection matrices are fused.
538
+
539
+ <Tip warning={true}>
540
+
541
+ This API is 🧪 experimental.
542
+
543
+ </Tip>
544
+ """
545
+ self.original_attn_processors = None
546
+
547
+ for _, attn_processor in self.attn_processors.items():
548
+ if "Added" in str(attn_processor.__class__.__name__):
549
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
550
+
551
+ self.original_attn_processors = self.attn_processors
552
+
553
+ for module in self.modules():
554
+ if isinstance(module, Attention):
555
+ module.fuse_projections(fuse=True)
556
+
557
+ self.set_attn_processor(FusedAttnProcessor2_0())
558
+
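+ # Illustrative call pattern: fuse before inference and unfuse afterwards, e.g.
+ #   vae.fuse_qkv_projections()    # switches every Attention module to FusedAttnProcessor2_0
+ #   ...                           # run inference
+ #   vae.unfuse_qkv_projections()  # restores the attention processors saved above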
559
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
560
+ def unfuse_qkv_projections(self):
561
+ """Disables the fused QKV projection if enabled.
562
+
563
+ <Tip warning={true}>
564
+
565
+ This API is 🧪 experimental.
566
+
567
+ </Tip>
568
+
569
+ """
570
+ if self.original_attn_processors is not None:
571
+ self.set_attn_processor(self.original_attn_processors)
icedit/diffusers/models/autoencoders/autoencoder_kl_allegro.py ADDED
@@ -0,0 +1,1149 @@
1
+ # Copyright 2024 The RhymesAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..attention_processor import Attention, SpatialNorm
25
+ from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
26
+ from ..downsampling import Downsample2D
27
+ from ..modeling_outputs import AutoencoderKLOutput
28
+ from ..modeling_utils import ModelMixin
29
+ from ..resnet import ResnetBlock2D
30
+ from ..upsampling import Upsample2D
31
+
32
+
33
+ class AllegroTemporalConvLayer(nn.Module):
34
+ r"""
35
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
36
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ in_dim: int,
42
+ out_dim: Optional[int] = None,
43
+ dropout: float = 0.0,
44
+ norm_num_groups: int = 32,
45
+ up_sample: bool = False,
46
+ down_sample: bool = False,
47
+ stride: int = 1,
48
+ ) -> None:
49
+ super().__init__()
50
+
51
+ out_dim = out_dim or in_dim
52
+ pad_h = pad_w = int((stride - 1) * 0.5)
53
+ pad_t = 0
54
+
55
+ self.down_sample = down_sample
56
+ self.up_sample = up_sample
57
+
58
+ if down_sample:
59
+ self.conv1 = nn.Sequential(
60
+ nn.GroupNorm(norm_num_groups, in_dim),
61
+ nn.SiLU(),
62
+ nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
63
+ )
64
+ elif up_sample:
65
+ self.conv1 = nn.Sequential(
66
+ nn.GroupNorm(norm_num_groups, in_dim),
67
+ nn.SiLU(),
68
+ nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
69
+ )
70
+ else:
71
+ self.conv1 = nn.Sequential(
72
+ nn.GroupNorm(norm_num_groups, in_dim),
73
+ nn.SiLU(),
74
+ nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
75
+ )
76
+ self.conv2 = nn.Sequential(
77
+ nn.GroupNorm(norm_num_groups, out_dim),
78
+ nn.SiLU(),
79
+ nn.Dropout(dropout),
80
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
81
+ )
82
+ self.conv3 = nn.Sequential(
83
+ nn.GroupNorm(norm_num_groups, out_dim),
84
+ nn.SiLU(),
85
+ nn.Dropout(dropout),
86
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
87
+ )
88
+ self.conv4 = nn.Sequential(
89
+ nn.GroupNorm(norm_num_groups, out_dim),
90
+ nn.SiLU(),
91
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
92
+ )
93
+
94
+ @staticmethod
95
+ def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
96
+ hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
97
+ hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
98
+ return hidden_states
99
+
100
+ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
101
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
102
+
103
+ if self.down_sample:
104
+ identity = hidden_states[:, :, ::2]
105
+ elif self.up_sample:
106
+ identity = hidden_states.repeat_interleave(2, dim=2)
107
+ else:
108
+ identity = hidden_states
109
+
110
+ if self.down_sample or self.up_sample:
111
+ hidden_states = self.conv1(hidden_states)
112
+ else:
113
+ hidden_states = self._pad_temporal_dim(hidden_states)
114
+ hidden_states = self.conv1(hidden_states)
115
+
116
+ if self.up_sample:
117
+ hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
118
+
119
+ hidden_states = self._pad_temporal_dim(hidden_states)
120
+ hidden_states = self.conv2(hidden_states)
121
+
122
+ hidden_states = self._pad_temporal_dim(hidden_states)
123
+ hidden_states = self.conv3(hidden_states)
124
+
125
+ hidden_states = self._pad_temporal_dim(hidden_states)
126
+ hidden_states = self.conv4(hidden_states)
127
+
128
+ hidden_states = identity + hidden_states
129
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
130
+
131
+ return hidden_states
132
+
133
+
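+ # Shape walkthrough for AllegroTemporalConvLayer.forward (illustrative): the caller passes
+ # spatially processed frames of shape (batch_size * num_frames, channels, height, width);
+ # `unflatten` + `permute` rebuild (batch_size, channels, num_frames, height, width) so the
+ # Conv3d stacks can mix information across time, and the final `permute` + `flatten` return
+ # the (B * T, C, H, W) layout expected by the surrounding 2D resnet blocks.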
134
+ class AllegroDownBlock3D(nn.Module):
135
+ def __init__(
136
+ self,
137
+ in_channels: int,
138
+ out_channels: int,
139
+ dropout: float = 0.0,
140
+ num_layers: int = 1,
141
+ resnet_eps: float = 1e-6,
142
+ resnet_time_scale_shift: str = "default",
143
+ resnet_act_fn: str = "swish",
144
+ resnet_groups: int = 32,
145
+ resnet_pre_norm: bool = True,
146
+ output_scale_factor: float = 1.0,
147
+ spatial_downsample: bool = True,
148
+ temporal_downsample: bool = False,
149
+ downsample_padding: int = 1,
150
+ ):
151
+ super().__init__()
152
+
153
+ resnets = []
154
+ temp_convs = []
155
+
156
+ for i in range(num_layers):
157
+ in_channels = in_channels if i == 0 else out_channels
158
+ resnets.append(
159
+ ResnetBlock2D(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ temb_channels=None,
163
+ eps=resnet_eps,
164
+ groups=resnet_groups,
165
+ dropout=dropout,
166
+ time_embedding_norm=resnet_time_scale_shift,
167
+ non_linearity=resnet_act_fn,
168
+ output_scale_factor=output_scale_factor,
169
+ pre_norm=resnet_pre_norm,
170
+ )
171
+ )
172
+ temp_convs.append(
173
+ AllegroTemporalConvLayer(
174
+ out_channels,
175
+ out_channels,
176
+ dropout=0.1,
177
+ norm_num_groups=resnet_groups,
178
+ )
179
+ )
180
+
181
+ self.resnets = nn.ModuleList(resnets)
182
+ self.temp_convs = nn.ModuleList(temp_convs)
183
+
184
+ if temporal_downsample:
185
+ self.temp_convs_down = AllegroTemporalConvLayer(
186
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
187
+ )
188
+ self.add_temp_downsample = temporal_downsample
189
+
190
+ if spatial_downsample:
191
+ self.downsamplers = nn.ModuleList(
192
+ [
193
+ Downsample2D(
194
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
195
+ )
196
+ ]
197
+ )
198
+ else:
199
+ self.downsamplers = None
200
+
201
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
202
+ batch_size = hidden_states.shape[0]
203
+
204
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
205
+
206
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
207
+ hidden_states = resnet(hidden_states, temb=None)
208
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
209
+
210
+ if self.add_temp_downsample:
211
+ hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size)
212
+
213
+ if self.downsamplers is not None:
214
+ for downsampler in self.downsamplers:
215
+ hidden_states = downsampler(hidden_states)
216
+
217
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
218
+ return hidden_states
219
+
220
+
221
+ class AllegroUpBlock3D(nn.Module):
222
+ def __init__(
223
+ self,
224
+ in_channels: int,
225
+ out_channels: int,
226
+ dropout: float = 0.0,
227
+ num_layers: int = 1,
228
+ resnet_eps: float = 1e-6,
229
+ resnet_time_scale_shift: str = "default", # default, spatial
230
+ resnet_act_fn: str = "swish",
231
+ resnet_groups: int = 32,
232
+ resnet_pre_norm: bool = True,
233
+ output_scale_factor: float = 1.0,
234
+ spatial_upsample: bool = True,
235
+ temporal_upsample: bool = False,
236
+ temb_channels: Optional[int] = None,
237
+ ):
238
+ super().__init__()
239
+
240
+ resnets = []
241
+ temp_convs = []
242
+
243
+ for i in range(num_layers):
244
+ input_channels = in_channels if i == 0 else out_channels
245
+
246
+ resnets.append(
247
+ ResnetBlock2D(
248
+ in_channels=input_channels,
249
+ out_channels=out_channels,
250
+ temb_channels=temb_channels,
251
+ eps=resnet_eps,
252
+ groups=resnet_groups,
253
+ dropout=dropout,
254
+ time_embedding_norm=resnet_time_scale_shift,
255
+ non_linearity=resnet_act_fn,
256
+ output_scale_factor=output_scale_factor,
257
+ pre_norm=resnet_pre_norm,
258
+ )
259
+ )
260
+ temp_convs.append(
261
+ AllegroTemporalConvLayer(
262
+ out_channels,
263
+ out_channels,
264
+ dropout=0.1,
265
+ norm_num_groups=resnet_groups,
266
+ )
267
+ )
268
+
269
+ self.resnets = nn.ModuleList(resnets)
270
+ self.temp_convs = nn.ModuleList(temp_convs)
271
+
272
+ self.add_temp_upsample = temporal_upsample
273
+ if temporal_upsample:
274
+ self.temp_conv_up = AllegroTemporalConvLayer(
275
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
276
+ )
277
+
278
+ if spatial_upsample:
279
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
280
+ else:
281
+ self.upsamplers = None
282
+
283
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
284
+ batch_size = hidden_states.shape[0]
285
+
286
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
287
+
288
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
289
+ hidden_states = resnet(hidden_states, temb=None)
290
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
291
+
292
+ if self.add_temp_upsample:
293
+ hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size)
294
+
295
+ if self.upsamplers is not None:
296
+ for upsampler in self.upsamplers:
297
+ hidden_states = upsampler(hidden_states)
298
+
299
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
300
+ return hidden_states
301
+
302
+
303
+ class AllegroMidBlock3DConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channels: int,
307
+ temb_channels: int,
308
+ dropout: float = 0.0,
309
+ num_layers: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_time_scale_shift: str = "default", # default, spatial
312
+ resnet_act_fn: str = "swish",
313
+ resnet_groups: int = 32,
314
+ resnet_pre_norm: bool = True,
315
+ add_attention: bool = True,
316
+ attention_head_dim: int = 1,
317
+ output_scale_factor: float = 1.0,
318
+ ):
319
+ super().__init__()
320
+
321
+ # there is always at least one resnet
322
+ resnets = [
323
+ ResnetBlock2D(
324
+ in_channels=in_channels,
325
+ out_channels=in_channels,
326
+ temb_channels=temb_channels,
327
+ eps=resnet_eps,
328
+ groups=resnet_groups,
329
+ dropout=dropout,
330
+ time_embedding_norm=resnet_time_scale_shift,
331
+ non_linearity=resnet_act_fn,
332
+ output_scale_factor=output_scale_factor,
333
+ pre_norm=resnet_pre_norm,
334
+ )
335
+ ]
336
+ temp_convs = [
337
+ AllegroTemporalConvLayer(
338
+ in_channels,
339
+ in_channels,
340
+ dropout=0.1,
341
+ norm_num_groups=resnet_groups,
342
+ )
343
+ ]
344
+ attentions = []
345
+
346
+ if attention_head_dim is None:
347
+ attention_head_dim = in_channels
348
+
349
+ for _ in range(num_layers):
350
+ if add_attention:
351
+ attentions.append(
352
+ Attention(
353
+ in_channels,
354
+ heads=in_channels // attention_head_dim,
355
+ dim_head=attention_head_dim,
356
+ rescale_output_factor=output_scale_factor,
357
+ eps=resnet_eps,
358
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
359
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
360
+ residual_connection=True,
361
+ bias=True,
362
+ upcast_softmax=True,
363
+ _from_deprecated_attn_block=True,
364
+ )
365
+ )
366
+ else:
367
+ attentions.append(None)
368
+
369
+ resnets.append(
370
+ ResnetBlock2D(
371
+ in_channels=in_channels,
372
+ out_channels=in_channels,
373
+ temb_channels=temb_channels,
374
+ eps=resnet_eps,
375
+ groups=resnet_groups,
376
+ dropout=dropout,
377
+ time_embedding_norm=resnet_time_scale_shift,
378
+ non_linearity=resnet_act_fn,
379
+ output_scale_factor=output_scale_factor,
380
+ pre_norm=resnet_pre_norm,
381
+ )
382
+ )
383
+
384
+ temp_convs.append(
385
+ AllegroTemporalConvLayer(
386
+ in_channels,
387
+ in_channels,
388
+ dropout=0.1,
389
+ norm_num_groups=resnet_groups,
390
+ )
391
+ )
392
+
393
+ self.resnets = nn.ModuleList(resnets)
394
+ self.temp_convs = nn.ModuleList(temp_convs)
395
+ self.attentions = nn.ModuleList(attentions)
396
+
397
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
398
+ batch_size = hidden_states.shape[0]
399
+
400
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
401
+ hidden_states = self.resnets[0](hidden_states, temb=None)
402
+
403
+ hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
404
+
405
+ for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
406
+ hidden_states = attn(hidden_states)
407
+ hidden_states = resnet(hidden_states, temb=None)
408
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
409
+
410
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
411
+ return hidden_states
412
+
413
+
414
+ class AllegroEncoder3D(nn.Module):
415
+ def __init__(
416
+ self,
417
+ in_channels: int = 3,
418
+ out_channels: int = 3,
419
+ down_block_types: Tuple[str, ...] = (
420
+ "AllegroDownBlock3D",
421
+ "AllegroDownBlock3D",
422
+ "AllegroDownBlock3D",
423
+ "AllegroDownBlock3D",
424
+ ),
425
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
426
+ temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
427
+ layers_per_block: int = 2,
428
+ norm_num_groups: int = 32,
429
+ act_fn: str = "silu",
430
+ double_z: bool = True,
431
+ ):
432
+ super().__init__()
433
+
434
+ self.conv_in = nn.Conv2d(
435
+ in_channels,
436
+ block_out_channels[0],
437
+ kernel_size=3,
438
+ stride=1,
439
+ padding=1,
440
+ )
441
+
442
+ self.temp_conv_in = nn.Conv3d(
443
+ in_channels=block_out_channels[0],
444
+ out_channels=block_out_channels[0],
445
+ kernel_size=(3, 1, 1),
446
+ padding=(1, 0, 0),
447
+ )
448
+
449
+ self.down_blocks = nn.ModuleList([])
450
+
451
+ # down
452
+ output_channel = block_out_channels[0]
453
+ for i, down_block_type in enumerate(down_block_types):
454
+ input_channel = output_channel
455
+ output_channel = block_out_channels[i]
456
+ is_final_block = i == len(block_out_channels) - 1
457
+
458
+ if down_block_type == "AllegroDownBlock3D":
459
+ down_block = AllegroDownBlock3D(
460
+ num_layers=layers_per_block,
461
+ in_channels=input_channel,
462
+ out_channels=output_channel,
463
+ spatial_downsample=not is_final_block,
464
+ temporal_downsample=temporal_downsample_blocks[i],
465
+ resnet_eps=1e-6,
466
+ downsample_padding=0,
467
+ resnet_act_fn=act_fn,
468
+ resnet_groups=norm_num_groups,
469
+ )
470
+ else:
471
+ raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`")
472
+
473
+ self.down_blocks.append(down_block)
474
+
475
+ # mid
476
+ self.mid_block = AllegroMidBlock3DConv(
477
+ in_channels=block_out_channels[-1],
478
+ resnet_eps=1e-6,
479
+ resnet_act_fn=act_fn,
480
+ output_scale_factor=1,
481
+ resnet_time_scale_shift="default",
482
+ attention_head_dim=block_out_channels[-1],
483
+ resnet_groups=norm_num_groups,
484
+ temb_channels=None,
485
+ )
486
+
487
+ # out
488
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
489
+ self.conv_act = nn.SiLU()
490
+
491
+ conv_out_channels = 2 * out_channels if double_z else out_channels
492
+
493
+ self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
494
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
495
+
496
+ self.gradient_checkpointing = False
497
+
498
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
499
+ batch_size = sample.shape[0]
500
+
501
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
502
+ sample = self.conv_in(sample)
503
+
504
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
505
+ residual = sample
506
+ sample = self.temp_conv_in(sample)
507
+ sample = sample + residual
508
+
509
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
510
+
511
+ def create_custom_forward(module):
512
+ def custom_forward(*inputs):
513
+ return module(*inputs)
514
+
515
+ return custom_forward
516
+
517
+ # Down blocks
518
+ for down_block in self.down_blocks:
519
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
520
+
521
+ # Mid block
522
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
523
+ else:
524
+ # Down blocks
525
+ for down_block in self.down_blocks:
526
+ sample = down_block(sample)
527
+
528
+ # Mid block
529
+ sample = self.mid_block(sample)
530
+
531
+ # Post process
532
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
533
+ sample = self.conv_norm_out(sample)
534
+ sample = self.conv_act(sample)
535
+
536
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
537
+ residual = sample
538
+ sample = self.temp_conv_out(sample)
539
+ sample = sample + residual
540
+
541
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
542
+ sample = self.conv_out(sample)
543
+
544
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
545
+ return sample
546
+
547
+
548
+ class AllegroDecoder3D(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int = 4,
552
+ out_channels: int = 3,
553
+ up_block_types: Tuple[str, ...] = (
554
+ "AllegroUpBlock3D",
555
+ "AllegroUpBlock3D",
556
+ "AllegroUpBlock3D",
557
+ "AllegroUpBlock3D",
558
+ ),
559
+ temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
560
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
561
+ layers_per_block: int = 2,
562
+ norm_num_groups: int = 32,
563
+ act_fn: str = "silu",
564
+ norm_type: str = "group", # group, spatial
565
+ ):
566
+ super().__init__()
567
+
568
+ self.conv_in = nn.Conv2d(
569
+ in_channels,
570
+ block_out_channels[-1],
571
+ kernel_size=3,
572
+ stride=1,
573
+ padding=1,
574
+ )
575
+
576
+ self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
577
+
578
+ self.mid_block = None
579
+ self.up_blocks = nn.ModuleList([])
580
+
581
+ temb_channels = in_channels if norm_type == "spatial" else None
582
+
583
+ # mid
584
+ self.mid_block = AllegroMidBlock3DConv(
585
+ in_channels=block_out_channels[-1],
586
+ resnet_eps=1e-6,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=1,
589
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
590
+ attention_head_dim=block_out_channels[-1],
591
+ resnet_groups=norm_num_groups,
592
+ temb_channels=temb_channels,
593
+ )
594
+
595
+ # up
596
+ reversed_block_out_channels = list(reversed(block_out_channels))
597
+ output_channel = reversed_block_out_channels[0]
598
+ for i, up_block_type in enumerate(up_block_types):
599
+ prev_output_channel = output_channel
600
+ output_channel = reversed_block_out_channels[i]
601
+
602
+ is_final_block = i == len(block_out_channels) - 1
603
+
604
+ if up_block_type == "AllegroUpBlock3D":
605
+ up_block = AllegroUpBlock3D(
606
+ num_layers=layers_per_block + 1,
607
+ in_channels=prev_output_channel,
608
+ out_channels=output_channel,
609
+ spatial_upsample=not is_final_block,
610
+ temporal_upsample=temporal_upsample_blocks[i],
611
+ resnet_eps=1e-6,
612
+ resnet_act_fn=act_fn,
613
+ resnet_groups=norm_num_groups,
614
+ temb_channels=temb_channels,
615
+ resnet_time_scale_shift=norm_type,
616
+ )
617
+ else:
618
+ raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`")
619
+
620
+ self.up_blocks.append(up_block)
621
+ prev_output_channel = output_channel
622
+
623
+ # out
624
+ if norm_type == "spatial":
625
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
626
+ else:
627
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
628
+
629
+ self.conv_act = nn.SiLU()
630
+
631
+ self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
632
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
633
+
634
+ self.gradient_checkpointing = False
635
+
636
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
637
+ batch_size = sample.shape[0]
638
+
639
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
640
+ sample = self.conv_in(sample)
641
+
642
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
643
+ residual = sample
644
+ sample = self.temp_conv_in(sample)
645
+ sample = sample + residual
646
+
647
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
648
+
649
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
650
+
651
+ def create_custom_forward(module):
652
+ def custom_forward(*inputs):
653
+ return module(*inputs)
654
+
655
+ return custom_forward
656
+
657
+ # Mid block
658
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
659
+
660
+ # Up blocks
661
+ for up_block in self.up_blocks:
662
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
663
+
664
+ else:
665
+ # Mid block
666
+ sample = self.mid_block(sample)
667
+ sample = sample.to(upscale_dtype)
668
+
669
+ # Up blocks
670
+ for up_block in self.up_blocks:
671
+ sample = up_block(sample)
672
+
673
+ # Post process
674
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
675
+ sample = self.conv_norm_out(sample)
676
+ sample = self.conv_act(sample)
677
+
678
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
679
+ residual = sample
680
+ sample = self.temp_conv_out(sample)
681
+ sample = sample + residual
682
+
683
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
684
+ sample = self.conv_out(sample)
685
+
686
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
687
+ return sample
688
+
689
+
690
+ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
691
+ r"""
692
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
693
+ [Allegro](https://github.com/rhymes-ai/Allegro).
694
+
695
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
696
+ for all models (such as downloading or saving).
697
+
698
+ Parameters:
699
+ in_channels (int, defaults to `3`):
700
+ Number of channels in the input image.
701
+ out_channels (int, defaults to `3`):
702
+ Number of channels in the output.
703
+ down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
704
+ Tuple of strings denoting which types of down blocks to use.
705
+ up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
706
+ Tuple of strings denoting which types of up blocks to use.
707
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
708
+ Tuple of integers denoting number of output channels in each block.
709
+ temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
710
+ Tuple of booleans denoting which blocks to enable temporal downsampling in.
711
+ latent_channels (`int`, defaults to `4`):
712
+ Number of channels in latents.
713
+ layers_per_block (`int`, defaults to `2`):
714
+ Number of resnet or attention or temporal convolution layers per down/up block.
715
+ act_fn (`str`, defaults to `"silu"`):
716
+ The activation function to use.
717
+ norm_num_groups (`int`, defaults to `32`):
718
+ Number of groups to use in normalization layers.
719
+ temporal_compression_ratio (`int`, defaults to `4`):
720
+ Ratio by which temporal dimension of samples are compressed.
721
+ sample_size (`int`, defaults to `320`):
722
+ Default spatial sample size used to build the tiling kernel.
723
+ scaling_factor (`float`, defaults to `0.13235`):
724
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
725
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
726
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
727
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
728
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
729
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
730
+ force_upcast (`bool`, default to `True`):
731
+ If enabled, it forces the VAE to run in float32 for high-resolution pipelines, such as SD-XL. The VAE
732
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
733
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
734
+ """
735
+
736
+ _supports_gradient_checkpointing = True
737
+
738
+ @register_to_config
739
+ def __init__(
740
+ self,
741
+ in_channels: int = 3,
742
+ out_channels: int = 3,
743
+ down_block_types: Tuple[str, ...] = (
744
+ "AllegroDownBlock3D",
745
+ "AllegroDownBlock3D",
746
+ "AllegroDownBlock3D",
747
+ "AllegroDownBlock3D",
748
+ ),
749
+ up_block_types: Tuple[str, ...] = (
750
+ "AllegroUpBlock3D",
751
+ "AllegroUpBlock3D",
752
+ "AllegroUpBlock3D",
753
+ "AllegroUpBlock3D",
754
+ ),
755
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
756
+ temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
757
+ temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
758
+ latent_channels: int = 4,
759
+ layers_per_block: int = 2,
760
+ act_fn: str = "silu",
761
+ norm_num_groups: int = 32,
762
+ temporal_compression_ratio: float = 4,
763
+ sample_size: int = 320,
764
+ scaling_factor: float = 0.13,
765
+ force_upcast: bool = True,
766
+ ) -> None:
767
+ super().__init__()
768
+
769
+ self.encoder = AllegroEncoder3D(
770
+ in_channels=in_channels,
771
+ out_channels=latent_channels,
772
+ down_block_types=down_block_types,
773
+ temporal_downsample_blocks=temporal_downsample_blocks,
774
+ block_out_channels=block_out_channels,
775
+ layers_per_block=layers_per_block,
776
+ act_fn=act_fn,
777
+ norm_num_groups=norm_num_groups,
778
+ double_z=True,
779
+ )
780
+ self.decoder = AllegroDecoder3D(
781
+ in_channels=latent_channels,
782
+ out_channels=out_channels,
783
+ up_block_types=up_block_types,
784
+ temporal_upsample_blocks=temporal_upsample_blocks,
785
+ block_out_channels=block_out_channels,
786
+ layers_per_block=layers_per_block,
787
+ norm_num_groups=norm_num_groups,
788
+ act_fn=act_fn,
789
+ )
790
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
791
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
792
+
793
+ # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
794
+ # to use a specific parameter here or in other VAEs.
795
+
796
+ self.use_slicing = False
797
+ self.use_tiling = False
798
+
799
+ self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
800
+ self.tile_overlap_t = 8
801
+ self.tile_overlap_h = 120
802
+ self.tile_overlap_w = 80
803
+ sample_frames = 24
804
+
805
+ self.kernel = (sample_frames, sample_size, sample_size)
806
+ self.stride = (
807
+ sample_frames - self.tile_overlap_t,
808
+ sample_size - self.tile_overlap_h,
809
+ sample_size - self.tile_overlap_w,
810
+ )
811
+
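+ # For intuition with the values above (sample_size=320, sample_frames=24 and temporal/height/width
+ # overlaps of 8/120/80): kernel = (24, 320, 320) and stride = (16, 200, 240), i.e. each encoded
+ # tile covers 24 frames at 320x320 pixels and consecutive tiles overlap by 8 frames, 120 rows
+ # and 80 columns.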
812
+ def _set_gradient_checkpointing(self, module, value=False):
813
+ if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
814
+ module.gradient_checkpointing = value
815
+
816
+ def enable_tiling(self) -> None:
817
+ r"""
818
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
819
+ compute encoding and decoding in several steps. This is useful for saving a large amount of memory and allows
820
+ processing larger videos.
821
+ """
822
+ self.use_tiling = True
823
+
824
+ def disable_tiling(self) -> None:
825
+ r"""
826
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
827
+ decoding in one step.
828
+ """
829
+ self.use_tiling = False
830
+
831
+ def enable_slicing(self) -> None:
832
+ r"""
833
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
834
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
835
+ """
836
+ self.use_slicing = True
837
+
838
+ def disable_slicing(self) -> None:
839
+ r"""
840
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
841
+ decoding in one step.
842
+ """
843
+ self.use_slicing = False
844
+
845
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
846
+ # TODO(aryan)
847
+ # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
848
+ if self.use_tiling:
849
+ return self.tiled_encode(x)
850
+
851
+ raise NotImplementedError("Encoding without tiling has not been implemented yet.")
852
+
853
+ @apply_forward_hook
854
+ def encode(
855
+ self, x: torch.Tensor, return_dict: bool = True
856
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
857
+ r"""
858
+ Encode a batch of videos into latents.
859
+
860
+ Args:
861
+ x (`torch.Tensor`):
862
+ Input batch of videos.
863
+ return_dict (`bool`, defaults to `True`):
864
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
865
+
866
+ Returns:
867
+ The latent representations of the encoded videos. If `return_dict` is True, a
868
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
869
+ """
870
+ if self.use_slicing and x.shape[0] > 1:
871
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
872
+ h = torch.cat(encoded_slices)
873
+ else:
874
+ h = self._encode(x)
875
+
876
+ posterior = DiagonalGaussianDistribution(h)
877
+
878
+ if not return_dict:
879
+ return (posterior,)
880
+ return AutoencoderKLOutput(latent_dist=posterior)
881
+
882
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
883
+ # TODO(aryan): refactor tiling implementation
884
+ # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
885
+ if self.use_tiling:
886
+ return self.tiled_decode(z)
887
+
888
+ raise NotImplementedError("Decoding without tiling has not been implemented yet.")
889
+
890
+ @apply_forward_hook
891
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
892
+ """
893
+ Decode a batch of videos.
894
+
895
+ Args:
896
+ z (`torch.Tensor`):
897
+ Input batch of latent vectors.
898
+ return_dict (`bool`, defaults to `True`):
899
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
900
+
901
+ Returns:
902
+ [`~models.vae.DecoderOutput`] or `tuple`:
903
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
904
+ returned.
905
+ """
906
+ if self.use_slicing and z.shape[0] > 1:
907
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
908
+ decoded = torch.cat(decoded_slices)
909
+ else:
910
+ decoded = self._decode(z)
911
+
912
+ if not return_dict:
913
+ return (decoded,)
914
+ return DecoderOutput(sample=decoded)
915
+
916
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
917
+ local_batch_size = 1
918
+ rs = self.spatial_compression_ratio
919
+ rt = self.config.temporal_compression_ratio
920
+
921
+ batch_size, num_channels, num_frames, height, width = x.shape
922
+
923
+ output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
924
+ output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
925
+ output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1
926
+
927
+ count = 0
928
+ output_latent = x.new_zeros(
929
+ (
930
+ output_num_frames * output_height * output_width,
931
+ 2 * self.config.latent_channels,
932
+ self.kernel[0] // rt,
933
+ self.kernel[1] // rs,
934
+ self.kernel[2] // rs,
935
+ )
936
+ )
937
+ vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]))
938
+
939
+ for i in range(output_num_frames):
940
+ for j in range(output_height):
941
+ for k in range(output_width):
942
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
943
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
944
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
945
+
946
+ video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
947
+ vae_batch_input[count % local_batch_size] = video_cube
948
+
949
+ if (
950
+ count % local_batch_size == local_batch_size - 1
951
+ or count == output_num_frames * output_height * output_width - 1
952
+ ):
953
+ latent = self.encoder(vae_batch_input)
954
+
955
+ if (
956
+ count == output_num_frames * output_height * output_width - 1
957
+ and count % local_batch_size != local_batch_size - 1
958
+ ):
959
+ output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
960
+ else:
961
+ output_latent[count - local_batch_size + 1 : count + 1] = latent
962
+
963
+ vae_batch_input = x.new_zeros(
964
+ (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])
965
+ )
966
+
967
+ count += 1
968
+
969
+ latent = x.new_zeros(
970
+ (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)
971
+ )
972
+ output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
973
+ output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
974
+ output_overlap = (
975
+ output_kernel[0] - output_stride[0],
976
+ output_kernel[1] - output_stride[1],
977
+ output_kernel[2] - output_stride[2],
978
+ )
979
+
980
+ for i in range(output_num_frames):
981
+ n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
982
+ for j in range(output_height):
983
+ h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
984
+ for k in range(output_width):
985
+ w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
986
+ latent_mean = _prepare_for_blend(
987
+ (i, output_num_frames, output_overlap[0]),
988
+ (j, output_height, output_overlap[1]),
989
+ (k, output_width, output_overlap[2]),
990
+ output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
991
+ )
992
+ latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean
993
+
994
+ latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)
995
+ latent = self.quant_conv(latent)
996
+ latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
997
+ return latent
998
+
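+ # The number of tiles per axis in `tiled_encode` follows the usual sliding-window formula, e.g.
+ # along time: output_num_frames = floor((num_frames - kernel[0]) / stride[0]) + 1. As a purely
+ # hypothetical example, an 88-frame video at 344x560 pixels with the default kernel/stride gives
+ # floor((88 - 24) / 16) + 1 = 5 temporal windows, floor((344 - 320) / 200) + 1 = 1 row of spatial
+ # tiles and floor((560 - 320) / 240) + 1 = 2 columns.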
999
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
1000
+ local_batch_size = 1
1001
+ rs = self.spatial_compression_ratio
1002
+ rt = self.config.temporal_compression_ratio
1003
+
1004
+ latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
1005
+ latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
1006
+
1007
+ batch_size, num_channels, num_frames, height, width = z.shape
1008
+
1009
+ ## post quant conv (a mapping)
1010
+ z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
1011
+ z = self.post_quant_conv(z)
1012
+ z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
1013
+
1014
+ output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
1015
+ output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
1016
+ output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1
1017
+
1018
+ count = 0
1019
+ decoded_videos = z.new_zeros(
1020
+ (
1021
+ output_num_frames * output_height * output_width,
1022
+ self.config.out_channels,
1023
+ self.kernel[0],
1024
+ self.kernel[1],
1025
+ self.kernel[2],
1026
+ )
1027
+ )
1028
+ vae_batch_input = z.new_zeros(
1029
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
1030
+ )
1031
+
1032
+ for i in range(output_num_frames):
1033
+ for j in range(output_height):
1034
+ for k in range(output_width):
1035
+ n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
1036
+ h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
1037
+ w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]
1038
+
1039
+ current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
1040
+ vae_batch_input[count % local_batch_size] = current_latent
1041
+
1042
+ if (
1043
+ count % local_batch_size == local_batch_size - 1
1044
+ or count == output_num_frames * output_height * output_width - 1
1045
+ ):
1046
+ current_video = self.decoder(vae_batch_input)
1047
+
1048
+ if (
1049
+ count == output_num_frames * output_height * output_width - 1
1050
+ and count % local_batch_size != local_batch_size - 1
1051
+ ):
1052
+ decoded_videos[count - count % local_batch_size :] = current_video[
1053
+ : count % local_batch_size + 1
1054
+ ]
1055
+ else:
1056
+ decoded_videos[count - local_batch_size + 1 : count + 1] = current_video
1057
+
1058
+ vae_batch_input = z.new_zeros(
1059
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
1060
+ )
1061
+
1062
+ count += 1
1063
+
1064
+ video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs))
1065
+ video_overlap = (
1066
+ self.kernel[0] - self.stride[0],
1067
+ self.kernel[1] - self.stride[1],
1068
+ self.kernel[2] - self.stride[2],
1069
+ )
1070
+
1071
+ for i in range(output_num_frames):
1072
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
1073
+ for j in range(output_height):
1074
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
1075
+ for k in range(output_width):
1076
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
1077
+ out_video_blend = _prepare_for_blend(
1078
+ (i, output_num_frames, video_overlap[0]),
1079
+ (j, output_height, video_overlap[1]),
1080
+ (k, output_width, video_overlap[2]),
1081
+ decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
1082
+ )
1083
+ video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
1084
+
1085
+ video = video.permute(0, 2, 1, 3, 4).contiguous()
1086
+ return video
1087
+
1088
+ def forward(
1089
+ self,
1090
+ sample: torch.Tensor,
1091
+ sample_posterior: bool = False,
1092
+ return_dict: bool = True,
1093
+ generator: Optional[torch.Generator] = None,
1094
+ ) -> Union[DecoderOutput, torch.Tensor]:
1095
+ r"""
1096
+ Args:
1097
+ sample (`torch.Tensor`): Input sample.
1098
+ sample_posterior (`bool`, *optional*, defaults to `False`):
1099
+ Whether to sample from the posterior.
1100
+ return_dict (`bool`, *optional*, defaults to `True`):
1101
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1102
+ generator (`torch.Generator`, *optional*):
1103
+ PyTorch random number generator.
1104
+ """
1105
+ x = sample
1106
+ posterior = self.encode(x).latent_dist
1107
+ if sample_posterior:
1108
+ z = posterior.sample(generator=generator)
1109
+ else:
1110
+ z = posterior.mode()
1111
+ dec = self.decode(z).sample
1112
+
1113
+ if not return_dict:
1114
+ return (dec,)
1115
+
1116
+ return DecoderOutput(sample=dec)
1117
+
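+ # Usage note (illustrative): because `_encode` and `_decode` above raise NotImplementedError when
+ # tiling is disabled, a typical round trip with AutoencoderKLAllegro is expected to look like
+ #   vae.enable_tiling()
+ #   latents = vae.encode(video).latent_dist.sample()
+ #   frames = vae.decode(latents).sample
+ # where `video` is a hypothetical (B, C, T, H, W) tensor in the model's expected value range.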
1118
+
1119
+ def _prepare_for_blend(n_param, h_param, w_param, x):
1120
+ # TODO(aryan): refactor
1121
+ n, n_max, overlap_n = n_param
1122
+ h, h_max, overlap_h = h_param
1123
+ w, w_max, overlap_w = w_param
1124
+ if overlap_n > 0:
1125
+ if n > 0: # the head overlap part decays from 0 to 1
1126
+ x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * (
1127
+ torch.arange(0, overlap_n).float().to(x.device) / overlap_n
1128
+ ).reshape(overlap_n, 1, 1)
1129
+ if n < n_max - 1: # the tail overlap part decays from 1 to 0
1130
+ x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * (
1131
+ 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n
1132
+ ).reshape(overlap_n, 1, 1)
1133
+ if h > 0:
1134
+ x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * (
1135
+ torch.arange(0, overlap_h).float().to(x.device) / overlap_h
1136
+ ).reshape(overlap_h, 1)
1137
+ if h < h_max - 1:
1138
+ x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * (
1139
+ 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h
1140
+ ).reshape(overlap_h, 1)
1141
+ if w > 0:
1142
+ x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * (
1143
+ torch.arange(0, overlap_w).float().to(x.device) / overlap_w
1144
+ )
1145
+ if w < w_max - 1:
1146
+ x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * (
1147
+ 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w
1148
+ )
1149
+ return x
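+ # In effect, `_prepare_for_blend` applies a linear cross-fade to each tile: within an overlap of
+ # length `overlap_n` (and likewise for height and width), the leading edge is weighted 0 -> 1 and
+ # the trailing edge 1 -> 0, so when the weighted tiles are summed into the output canvas the
+ # contributions of neighbouring tiles add up to 1 across the overlap.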
icedit/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py ADDED
@@ -0,0 +1,1482 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...loaders.single_file_model import FromOriginalModelMixin
25
+ from ...utils import logging
26
+ from ...utils.accelerate_utils import apply_forward_hook
27
+ from ..activations import get_activation
28
+ from ..downsampling import CogVideoXDownsample3D
29
+ from ..modeling_outputs import AutoencoderKLOutput
30
+ from ..modeling_utils import ModelMixin
31
+ from ..upsampling import CogVideoXUpsample3D
32
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class CogVideoXSafeConv3d(nn.Conv3d):
39
+ r"""
40
+ A 3D convolution layer that splits the input tensor into smaller chunks to avoid out-of-memory errors in the CogVideoX model.
41
+ """
42
+
43
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
44
+ memory_count = (
45
+ (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
46
+ )
47
+
48
+ # Chunk the input when its estimated size exceeds 2 GB, a threshold that works well with cuDNN
49
+ if memory_count > 2:
50
+ kernel_size = self.kernel_size[0]
51
+ part_num = int(memory_count / 2) + 1
52
+ input_chunks = torch.chunk(input, part_num, dim=2)
53
+
54
+ if kernel_size > 1:
55
+ input_chunks = [input_chunks[0]] + [
56
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
57
+ for i in range(1, len(input_chunks))
58
+ ]
59
+
60
+ output_chunks = []
61
+ for input_chunk in input_chunks:
62
+ output_chunks.append(super().forward(input_chunk))
63
+ output = torch.cat(output_chunks, dim=2)
64
+ return output
65
+ else:
66
+ return super().forward(input)
67
+
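+ # For intuition: `memory_count` estimates the input activation size in GiB assuming 2 bytes per
+ # element (fp16/bf16). A hypothetical (1, 128, 49, 240, 360) half-precision input is about
+ # 1 * 128 * 49 * 240 * 360 * 2 / 1024**3 ~= 1.0 GiB and is convolved in one pass, while larger
+ # inputs are chunked along the time axis; prepending the last (kernel_size - 1) frames of the
+ # previous chunk gives each chunk the temporal context it would have had in a single pass.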
68
+
69
+ class CogVideoXCausalConv3d(nn.Module):
70
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
71
+
72
+ Args:
73
+ in_channels (`int`): Number of channels in the input tensor.
74
+ out_channels (`int`): Number of output channels produced by the convolution.
75
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
76
+ stride (`int`, defaults to `1`): Stride of the convolution.
77
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
78
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ in_channels: int,
84
+ out_channels: int,
85
+ kernel_size: Union[int, Tuple[int, int, int]],
86
+ stride: int = 1,
87
+ dilation: int = 1,
88
+ pad_mode: str = "constant",
89
+ ):
90
+ super().__init__()
91
+
92
+ if isinstance(kernel_size, int):
93
+ kernel_size = (kernel_size,) * 3
94
+
95
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
96
+
97
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
98
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
99
+ time_pad = time_kernel_size - 1
100
+ height_pad = (height_kernel_size - 1) // 2
101
+ width_pad = (width_kernel_size - 1) // 2
102
+
103
+ self.pad_mode = pad_mode
104
+ self.height_pad = height_pad
105
+ self.width_pad = width_pad
106
+ self.time_pad = time_pad
107
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
108
+
109
+ self.temporal_dim = 2
110
+ self.time_kernel_size = time_kernel_size
111
+
112
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
113
+ dilation = (dilation, 1, 1)
114
+ self.conv = CogVideoXSafeConv3d(
115
+ in_channels=in_channels,
116
+ out_channels=out_channels,
117
+ kernel_size=kernel_size,
118
+ stride=stride,
119
+ dilation=dilation,
120
+ )
121
+
122
+ def fake_context_parallel_forward(
123
+ self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
124
+ ) -> torch.Tensor:
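+ # Prepends a temporal context of (time_kernel_size - 1) frames: either the cached tail of the previous
+ # chunk (`conv_cache`) or, for the very first chunk, copies of the first frame. This lets chunk-by-chunk
+ # processing reproduce the result of running the causal convolution over the full video.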
125
+ if self.pad_mode == "replicate":
126
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
127
+ else:
128
+ kernel_size = self.time_kernel_size
129
+ if kernel_size > 1:
130
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
131
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
132
+ return inputs
133
+
134
+ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
135
+ inputs = self.fake_context_parallel_forward(inputs, conv_cache)
136
+
137
+ if self.pad_mode == "replicate":
138
+ conv_cache = None
139
+ else:
140
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
141
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
142
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
143
+
144
+ output = self.conv(inputs)
145
+ return output, conv_cache
146
+
147
+
148
+ class CogVideoXSpatialNorm3D(nn.Module):
149
+ r"""
150
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
151
+ to 3D-video like data.
152
+
153
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
154
+
155
+ Args:
156
+ f_channels (`int`):
157
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
158
+ zq_channels (`int`):
159
+ The number of channels for the quantized vector as described in the paper.
160
+ groups (`int`):
161
+ Number of groups to separate the channels into for group normalization.
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ f_channels: int,
167
+ zq_channels: int,
168
+ groups: int = 32,
169
+ ):
170
+ super().__init__()
171
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
172
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
173
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
174
+
175
+ def forward(
176
+ self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
177
+ ) -> torch.Tensor:
178
+ new_conv_cache = {}
179
+ conv_cache = conv_cache or {}
180
+
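+ # When the number of frames is odd (a single leading frame plus an even remainder), the first frame of
+ # `zq` is interpolated separately from the rest so that the leading frame is not mixed with later frames
+ # when `zq` is resized to match `f`.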
181
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
182
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
183
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
184
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
185
+ z_first = F.interpolate(z_first, size=f_first_size)
186
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
187
+ zq = torch.cat([z_first, z_rest], dim=2)
188
+ else:
189
+ zq = F.interpolate(zq, size=f.shape[-3:])
190
+
191
+ conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
192
+ conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
193
+
194
+ norm_f = self.norm_layer(f)
195
+ new_f = norm_f * conv_y + conv_b
196
+ return new_f, new_conv_cache
197
+
198
+
199
+ class CogVideoXResnetBlock3D(nn.Module):
200
+ r"""
201
+ A 3D ResNet block used in the CogVideoX model.
202
+
203
+ Args:
204
+ in_channels (`int`):
205
+ Number of input channels.
206
+ out_channels (`int`, *optional*):
207
+ Number of output channels. If None, defaults to `in_channels`.
208
+ dropout (`float`, defaults to `0.0`):
209
+ Dropout rate.
210
+ temb_channels (`int`, defaults to `512`):
211
+ Number of time embedding channels.
212
+ groups (`int`, defaults to `32`):
213
+ Number of groups to separate the channels into for group normalization.
214
+ eps (`float`, defaults to `1e-6`):
215
+ Epsilon value for normalization layers.
216
+ non_linearity (`str`, defaults to `"swish"`):
217
+ Activation function to use.
218
+ conv_shortcut (bool, defaults to `False`):
219
+ Whether or not to use a convolution shortcut.
220
+ spatial_norm_dim (`int`, *optional*):
221
+ The dimension to use for spatial norm if it is to be used instead of group norm.
222
+ pad_mode (str, defaults to `"first"`):
223
+ Padding mode.
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ in_channels: int,
229
+ out_channels: Optional[int] = None,
230
+ dropout: float = 0.0,
231
+ temb_channels: int = 512,
232
+ groups: int = 32,
233
+ eps: float = 1e-6,
234
+ non_linearity: str = "swish",
235
+ conv_shortcut: bool = False,
236
+ spatial_norm_dim: Optional[int] = None,
237
+ pad_mode: str = "first",
238
+ ):
239
+ super().__init__()
240
+
241
+ out_channels = out_channels or in_channels
242
+
243
+ self.in_channels = in_channels
244
+ self.out_channels = out_channels
245
+ self.nonlinearity = get_activation(non_linearity)
246
+ self.use_conv_shortcut = conv_shortcut
247
+ self.spatial_norm_dim = spatial_norm_dim
248
+
249
+ if spatial_norm_dim is None:
250
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
251
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
252
+ else:
253
+ self.norm1 = CogVideoXSpatialNorm3D(
254
+ f_channels=in_channels,
255
+ zq_channels=spatial_norm_dim,
256
+ groups=groups,
257
+ )
258
+ self.norm2 = CogVideoXSpatialNorm3D(
259
+ f_channels=out_channels,
260
+ zq_channels=spatial_norm_dim,
261
+ groups=groups,
262
+ )
263
+
264
+ self.conv1 = CogVideoXCausalConv3d(
265
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
266
+ )
267
+
268
+ if temb_channels > 0:
269
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
270
+
271
+ self.dropout = nn.Dropout(dropout)
272
+ self.conv2 = CogVideoXCausalConv3d(
273
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
274
+ )
275
+
276
+ if self.in_channels != self.out_channels:
277
+ if self.use_conv_shortcut:
278
+ self.conv_shortcut = CogVideoXCausalConv3d(
279
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
280
+ )
281
+ else:
282
+ self.conv_shortcut = CogVideoXSafeConv3d(
283
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
284
+ )
285
+
286
+ def forward(
287
+ self,
288
+ inputs: torch.Tensor,
289
+ temb: Optional[torch.Tensor] = None,
290
+ zq: Optional[torch.Tensor] = None,
291
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
292
+ ) -> torch.Tensor:
293
+ new_conv_cache = {}
294
+ conv_cache = conv_cache or {}
295
+
296
+ hidden_states = inputs
297
+
298
+ if zq is not None:
299
+ hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
300
+ else:
301
+ hidden_states = self.norm1(hidden_states)
302
+
303
+ hidden_states = self.nonlinearity(hidden_states)
304
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
305
+
306
+ if temb is not None:
307
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
308
+
309
+ if zq is not None:
310
+ hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
311
+ else:
312
+ hidden_states = self.norm2(hidden_states)
313
+
314
+ hidden_states = self.nonlinearity(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
317
+
318
+ if self.in_channels != self.out_channels:
319
+ if self.use_conv_shortcut:
320
+ inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
321
+ inputs, conv_cache=conv_cache.get("conv_shortcut")
322
+ )
323
+ else:
324
+ inputs = self.conv_shortcut(inputs)
325
+
326
+ hidden_states = hidden_states + inputs
327
+ return hidden_states, new_conv_cache
328
+
329
+
330
+ class CogVideoXDownBlock3D(nn.Module):
331
+ r"""
332
+ A downsampling block used in the CogVideoX model.
333
+
334
+ Args:
335
+ in_channels (`int`):
336
+ Number of input channels.
337
+ out_channels (`int`, *optional*):
338
+ Number of output channels. If None, defaults to `in_channels`.
339
+ temb_channels (`int`, defaults to `512`):
340
+ Number of time embedding channels.
341
+ num_layers (`int`, defaults to `1`):
342
+ Number of resnet layers.
343
+ dropout (`float`, defaults to `0.0`):
344
+ Dropout rate.
345
+ resnet_eps (`float`, defaults to `1e-6`):
346
+ Epsilon value for normalization layers.
347
+ resnet_act_fn (`str`, defaults to `"swish"`):
348
+ Activation function to use.
349
+ resnet_groups (`int`, defaults to `32`):
350
+ Number of groups to separate the channels into for group normalization.
351
+ add_downsample (`bool`, defaults to `True`):
352
+ Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
353
+ compress_time (`bool`, defaults to `False`):
354
+ Whether or not to downsample across the temporal dimension.
355
+ pad_mode (str, defaults to `"first"`):
356
+ Padding mode.
357
+ """
358
+
359
+ _supports_gradient_checkpointing = True
360
+
361
+ def __init__(
362
+ self,
363
+ in_channels: int,
364
+ out_channels: int,
365
+ temb_channels: int,
366
+ dropout: float = 0.0,
367
+ num_layers: int = 1,
368
+ resnet_eps: float = 1e-6,
369
+ resnet_act_fn: str = "swish",
370
+ resnet_groups: int = 32,
371
+ add_downsample: bool = True,
372
+ downsample_padding: int = 0,
373
+ compress_time: bool = False,
374
+ pad_mode: str = "first",
375
+ ):
376
+ super().__init__()
377
+
378
+ resnets = []
379
+ for i in range(num_layers):
380
+ in_channel = in_channels if i == 0 else out_channels
381
+ resnets.append(
382
+ CogVideoXResnetBlock3D(
383
+ in_channels=in_channel,
384
+ out_channels=out_channels,
385
+ dropout=dropout,
386
+ temb_channels=temb_channels,
387
+ groups=resnet_groups,
388
+ eps=resnet_eps,
389
+ non_linearity=resnet_act_fn,
390
+ pad_mode=pad_mode,
391
+ )
392
+ )
393
+
394
+ self.resnets = nn.ModuleList(resnets)
395
+ self.downsamplers = None
396
+
397
+ if add_downsample:
398
+ self.downsamplers = nn.ModuleList(
399
+ [
400
+ CogVideoXDownsample3D(
401
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
402
+ )
403
+ ]
404
+ )
405
+
406
+ self.gradient_checkpointing = False
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ temb: Optional[torch.Tensor] = None,
412
+ zq: Optional[torch.Tensor] = None,
413
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
414
+ ) -> torch.Tensor:
415
+ r"""Forward method of the `CogVideoXDownBlock3D` class."""
416
+
417
+ new_conv_cache = {}
418
+ conv_cache = conv_cache or {}
419
+
420
+ for i, resnet in enumerate(self.resnets):
421
+ conv_cache_key = f"resnet_{i}"
422
+
423
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
424
+
425
+ def create_custom_forward(module):
426
+ def create_forward(*inputs):
427
+ return module(*inputs)
428
+
429
+ return create_forward
430
+
431
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
432
+ create_custom_forward(resnet),
433
+ hidden_states,
434
+ temb,
435
+ zq,
436
+ conv_cache.get(conv_cache_key),
437
+ )
438
+ else:
439
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
440
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
441
+ )
442
+
443
+ if self.downsamplers is not None:
444
+ for downsampler in self.downsamplers:
445
+ hidden_states = downsampler(hidden_states)
446
+
447
+ return hidden_states, new_conv_cache
448
+
449
+
450
+ class CogVideoXMidBlock3D(nn.Module):
451
+ r"""
452
+ A middle block used in the CogVideoX model.
453
+
454
+ Args:
455
+ in_channels (`int`):
456
+ Number of input channels.
457
+ temb_channels (`int`, defaults to `512`):
458
+ Number of time embedding channels.
459
+ dropout (`float`, defaults to `0.0`):
460
+ Dropout rate.
461
+ num_layers (`int`, defaults to `1`):
462
+ Number of resnet layers.
463
+ resnet_eps (`float`, defaults to `1e-6`):
464
+ Epsilon value for normalization layers.
465
+ resnet_act_fn (`str`, defaults to `"swish"`):
466
+ Activation function to use.
467
+ resnet_groups (`int`, defaults to `32`):
468
+ Number of groups to separate the channels into for group normalization.
469
+ spatial_norm_dim (`int`, *optional*):
470
+ The dimension to use for spatial norm if it is to be used instead of group norm.
471
+ pad_mode (str, defaults to `"first"`):
472
+ Padding mode.
473
+ """
474
+
475
+ _supports_gradient_checkpointing = True
476
+
477
+ def __init__(
478
+ self,
479
+ in_channels: int,
480
+ temb_channels: int,
481
+ dropout: float = 0.0,
482
+ num_layers: int = 1,
483
+ resnet_eps: float = 1e-6,
484
+ resnet_act_fn: str = "swish",
485
+ resnet_groups: int = 32,
486
+ spatial_norm_dim: Optional[int] = None,
487
+ pad_mode: str = "first",
488
+ ):
489
+ super().__init__()
490
+
491
+ resnets = []
492
+ for _ in range(num_layers):
493
+ resnets.append(
494
+ CogVideoXResnetBlock3D(
495
+ in_channels=in_channels,
496
+ out_channels=in_channels,
497
+ dropout=dropout,
498
+ temb_channels=temb_channels,
499
+ groups=resnet_groups,
500
+ eps=resnet_eps,
501
+ spatial_norm_dim=spatial_norm_dim,
502
+ non_linearity=resnet_act_fn,
503
+ pad_mode=pad_mode,
504
+ )
505
+ )
506
+ self.resnets = nn.ModuleList(resnets)
507
+
508
+ self.gradient_checkpointing = False
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.Tensor,
513
+ temb: Optional[torch.Tensor] = None,
514
+ zq: Optional[torch.Tensor] = None,
515
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
516
+ ) -> torch.Tensor:
517
+ r"""Forward method of the `CogVideoXMidBlock3D` class."""
518
+
519
+ new_conv_cache = {}
520
+ conv_cache = conv_cache or {}
521
+
522
+ for i, resnet in enumerate(self.resnets):
523
+ conv_cache_key = f"resnet_{i}"
524
+
525
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
526
+
527
+ def create_custom_forward(module):
528
+ def create_forward(*inputs):
529
+ return module(*inputs)
530
+
531
+ return create_forward
532
+
533
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
534
+ create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
535
+ )
536
+ else:
537
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
538
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
539
+ )
540
+
541
+ return hidden_states, new_conv_cache
542
+
543
+
544
+ class CogVideoXUpBlock3D(nn.Module):
545
+ r"""
546
+ An upsampling block used in the CogVideoX model.
547
+
548
+ Args:
549
+ in_channels (`int`):
550
+ Number of input channels.
551
+ out_channels (`int`, *optional*):
552
+ Number of output channels. If None, defaults to `in_channels`.
553
+ temb_channels (`int`, defaults to `512`):
554
+ Number of time embedding channels.
555
+ dropout (`float`, defaults to `0.0`):
556
+ Dropout rate.
557
+ num_layers (`int`, defaults to `1`):
558
+ Number of resnet layers.
559
+ resnet_eps (`float`, defaults to `1e-6`):
560
+ Epsilon value for normalization layers.
561
+ resnet_act_fn (`str`, defaults to `"swish"`):
562
+ Activation function to use.
563
+ resnet_groups (`int`, defaults to `32`):
564
+ Number of groups to separate the channels into for group normalization.
565
+ spatial_norm_dim (`int`, defaults to `16`):
566
+ The dimension to use for spatial norm if it is to be used instead of group norm.
567
+ add_upsample (`bool`, defaults to `True`):
568
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
569
+ compress_time (`bool`, defaults to `False`):
570
+ Whether or not to upsample across the temporal dimension.
571
+ pad_mode (str, defaults to `"first"`):
572
+ Padding mode.
573
+ """
574
+
575
+ def __init__(
576
+ self,
577
+ in_channels: int,
578
+ out_channels: int,
579
+ temb_channels: int,
580
+ dropout: float = 0.0,
581
+ num_layers: int = 1,
582
+ resnet_eps: float = 1e-6,
583
+ resnet_act_fn: str = "swish",
584
+ resnet_groups: int = 32,
585
+ spatial_norm_dim: int = 16,
586
+ add_upsample: bool = True,
587
+ upsample_padding: int = 1,
588
+ compress_time: bool = False,
589
+ pad_mode: str = "first",
590
+ ):
591
+ super().__init__()
592
+
593
+ resnets = []
594
+ for i in range(num_layers):
595
+ in_channel = in_channels if i == 0 else out_channels
596
+ resnets.append(
597
+ CogVideoXResnetBlock3D(
598
+ in_channels=in_channel,
599
+ out_channels=out_channels,
600
+ dropout=dropout,
601
+ temb_channels=temb_channels,
602
+ groups=resnet_groups,
603
+ eps=resnet_eps,
604
+ non_linearity=resnet_act_fn,
605
+ spatial_norm_dim=spatial_norm_dim,
606
+ pad_mode=pad_mode,
607
+ )
608
+ )
609
+
610
+ self.resnets = nn.ModuleList(resnets)
611
+ self.upsamplers = None
612
+
613
+ if add_upsample:
614
+ self.upsamplers = nn.ModuleList(
615
+ [
616
+ CogVideoXUpsample3D(
617
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
618
+ )
619
+ ]
620
+ )
621
+
622
+ self.gradient_checkpointing = False
623
+
624
+ def forward(
625
+ self,
626
+ hidden_states: torch.Tensor,
627
+ temb: Optional[torch.Tensor] = None,
628
+ zq: Optional[torch.Tensor] = None,
629
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
630
+ ) -> torch.Tensor:
631
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
632
+
633
+ new_conv_cache = {}
634
+ conv_cache = conv_cache or {}
635
+
636
+ for i, resnet in enumerate(self.resnets):
637
+ conv_cache_key = f"resnet_{i}"
638
+
639
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
640
+
641
+ def create_custom_forward(module):
642
+ def create_forward(*inputs):
643
+ return module(*inputs)
644
+
645
+ return create_forward
646
+
647
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
648
+ create_custom_forward(resnet),
649
+ hidden_states,
650
+ temb,
651
+ zq,
652
+ conv_cache.get(conv_cache_key),
653
+ )
654
+ else:
655
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
656
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
657
+ )
658
+
659
+ if self.upsamplers is not None:
660
+ for upsampler in self.upsamplers:
661
+ hidden_states = upsampler(hidden_states)
662
+
663
+ return hidden_states, new_conv_cache
664
+
665
+
666
+ class CogVideoXEncoder3D(nn.Module):
667
+ r"""
668
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
669
+
670
+ Args:
671
+ in_channels (`int`, *optional*, defaults to 3):
672
+ The number of input channels.
673
+ out_channels (`int`, *optional*, defaults to 16):
674
+ The number of output channels.
675
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D")`):
676
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
677
+ options.
678
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
679
+ The number of output channels for each block.
680
+ act_fn (`str`, *optional*, defaults to `"silu"`):
681
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
682
+ layers_per_block (`int`, *optional*, defaults to 3):
683
+ The number of layers per block.
684
+ norm_num_groups (`int`, *optional*, defaults to 32):
685
+ The number of groups for normalization.
686
+ """
687
+
688
+ _supports_gradient_checkpointing = True
689
+
690
+ def __init__(
691
+ self,
692
+ in_channels: int = 3,
693
+ out_channels: int = 16,
694
+ down_block_types: Tuple[str, ...] = (
695
+ "CogVideoXDownBlock3D",
696
+ "CogVideoXDownBlock3D",
697
+ "CogVideoXDownBlock3D",
698
+ "CogVideoXDownBlock3D",
699
+ ),
700
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
701
+ layers_per_block: int = 3,
702
+ act_fn: str = "silu",
703
+ norm_eps: float = 1e-6,
704
+ norm_num_groups: int = 32,
705
+ dropout: float = 0.0,
706
+ pad_mode: str = "first",
707
+ temporal_compression_ratio: float = 4,
708
+ ):
709
+ super().__init__()
710
+
711
+ # log2 of temporal_compress_times
712
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
713
+
714
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
715
+ self.down_blocks = nn.ModuleList([])
716
+
717
+ # down blocks
718
+ output_channel = block_out_channels[0]
719
+ for i, down_block_type in enumerate(down_block_types):
720
+ input_channel = output_channel
721
+ output_channel = block_out_channels[i]
722
+ is_final_block = i == len(block_out_channels) - 1
723
+ compress_time = i < temporal_compress_level
724
+
725
+ if down_block_type == "CogVideoXDownBlock3D":
726
+ down_block = CogVideoXDownBlock3D(
727
+ in_channels=input_channel,
728
+ out_channels=output_channel,
729
+ temb_channels=0,
730
+ dropout=dropout,
731
+ num_layers=layers_per_block,
732
+ resnet_eps=norm_eps,
733
+ resnet_act_fn=act_fn,
734
+ resnet_groups=norm_num_groups,
735
+ add_downsample=not is_final_block,
736
+ compress_time=compress_time,
737
+ )
738
+ else:
739
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
740
+
741
+ self.down_blocks.append(down_block)
742
+
743
+ # mid block
744
+ self.mid_block = CogVideoXMidBlock3D(
745
+ in_channels=block_out_channels[-1],
746
+ temb_channels=0,
747
+ dropout=dropout,
748
+ num_layers=2,
749
+ resnet_eps=norm_eps,
750
+ resnet_act_fn=act_fn,
751
+ resnet_groups=norm_num_groups,
752
+ pad_mode=pad_mode,
753
+ )
754
+
755
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
756
+ self.conv_act = nn.SiLU()
757
+ self.conv_out = CogVideoXCausalConv3d(
758
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
759
+ )
760
+
761
+ self.gradient_checkpointing = False
762
+
763
+ def forward(
764
+ self,
765
+ sample: torch.Tensor,
766
+ temb: Optional[torch.Tensor] = None,
767
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
768
+ ) -> torch.Tensor:
769
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
770
+
771
+ new_conv_cache = {}
772
+ conv_cache = conv_cache or {}
773
+
774
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
775
+
776
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
777
+
778
+ def create_custom_forward(module):
779
+ def custom_forward(*inputs):
780
+ return module(*inputs)
781
+
782
+ return custom_forward
783
+
784
+ # 1. Down
785
+ for i, down_block in enumerate(self.down_blocks):
786
+ conv_cache_key = f"down_block_{i}"
787
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
788
+ create_custom_forward(down_block),
789
+ hidden_states,
790
+ temb,
791
+ None,
792
+ conv_cache.get(conv_cache_key),
793
+ )
794
+
795
+ # 2. Mid
796
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
797
+ create_custom_forward(self.mid_block),
798
+ hidden_states,
799
+ temb,
800
+ None,
801
+ conv_cache.get("mid_block"),
802
+ )
803
+ else:
804
+ # 1. Down
805
+ for i, down_block in enumerate(self.down_blocks):
806
+ conv_cache_key = f"down_block_{i}"
807
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
808
+ hidden_states, temb, None, conv_cache.get(conv_cache_key)
809
+ )
810
+
811
+ # 2. Mid
812
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
813
+ hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
814
+ )
815
+
816
+ # 3. Post-process
817
+ hidden_states = self.norm_out(hidden_states)
818
+ hidden_states = self.conv_act(hidden_states)
819
+
820
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
821
+
822
+ return hidden_states, new_conv_cache
823
+
824
+
825
+ class CogVideoXDecoder3D(nn.Module):
826
+ r"""
827
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
828
+ sample.
829
+
830
+ Args:
831
+ in_channels (`int`, *optional*, defaults to 16):
832
+ The number of input channels.
833
+ out_channels (`int`, *optional*, defaults to 3):
834
+ The number of output channels.
835
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D")`):
836
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
837
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
838
+ The number of output channels for each block.
839
+ act_fn (`str`, *optional*, defaults to `"silu"`):
840
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
841
+ layers_per_block (`int`, *optional*, defaults to 3):
842
+ The number of layers per block.
843
+ norm_num_groups (`int`, *optional*, defaults to 32):
844
+ The number of groups for normalization.
845
+ """
846
+
847
+ _supports_gradient_checkpointing = True
848
+
849
+ def __init__(
850
+ self,
851
+ in_channels: int = 16,
852
+ out_channels: int = 3,
853
+ up_block_types: Tuple[str, ...] = (
854
+ "CogVideoXUpBlock3D",
855
+ "CogVideoXUpBlock3D",
856
+ "CogVideoXUpBlock3D",
857
+ "CogVideoXUpBlock3D",
858
+ ),
859
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
860
+ layers_per_block: int = 3,
861
+ act_fn: str = "silu",
862
+ norm_eps: float = 1e-6,
863
+ norm_num_groups: int = 32,
864
+ dropout: float = 0.0,
865
+ pad_mode: str = "first",
866
+ temporal_compression_ratio: float = 4,
867
+ ):
868
+ super().__init__()
869
+
870
+ reversed_block_out_channels = list(reversed(block_out_channels))
871
+
872
+ self.conv_in = CogVideoXCausalConv3d(
873
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
874
+ )
875
+
876
+ # mid block
877
+ self.mid_block = CogVideoXMidBlock3D(
878
+ in_channels=reversed_block_out_channels[0],
879
+ temb_channels=0,
880
+ num_layers=2,
881
+ resnet_eps=norm_eps,
882
+ resnet_act_fn=act_fn,
883
+ resnet_groups=norm_num_groups,
884
+ spatial_norm_dim=in_channels,
885
+ pad_mode=pad_mode,
886
+ )
887
+
888
+ # up blocks
889
+ self.up_blocks = nn.ModuleList([])
890
+
891
+ output_channel = reversed_block_out_channels[0]
892
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
893
+
894
+ for i, up_block_type in enumerate(up_block_types):
895
+ prev_output_channel = output_channel
896
+ output_channel = reversed_block_out_channels[i]
897
+ is_final_block = i == len(block_out_channels) - 1
898
+ compress_time = i < temporal_compress_level
899
+
900
+ if up_block_type == "CogVideoXUpBlock3D":
901
+ up_block = CogVideoXUpBlock3D(
902
+ in_channels=prev_output_channel,
903
+ out_channels=output_channel,
904
+ temb_channels=0,
905
+ dropout=dropout,
906
+ num_layers=layers_per_block + 1,
907
+ resnet_eps=norm_eps,
908
+ resnet_act_fn=act_fn,
909
+ resnet_groups=norm_num_groups,
910
+ spatial_norm_dim=in_channels,
911
+ add_upsample=not is_final_block,
912
+ compress_time=compress_time,
913
+ pad_mode=pad_mode,
914
+ )
915
+ prev_output_channel = output_channel
916
+ else:
917
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
918
+
919
+ self.up_blocks.append(up_block)
920
+
921
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
922
+ self.conv_act = nn.SiLU()
923
+ self.conv_out = CogVideoXCausalConv3d(
924
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
925
+ )
926
+
927
+ self.gradient_checkpointing = False
928
+
929
+ def forward(
930
+ self,
931
+ sample: torch.Tensor,
932
+ temb: Optional[torch.Tensor] = None,
933
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
934
+ ) -> torch.Tensor:
935
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
936
+
937
+ new_conv_cache = {}
938
+ conv_cache = conv_cache or {}
939
+
940
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
941
+
942
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
943
+
944
+ def create_custom_forward(module):
945
+ def custom_forward(*inputs):
946
+ return module(*inputs)
947
+
948
+ return custom_forward
949
+
950
+ # 1. Mid
951
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
952
+ create_custom_forward(self.mid_block),
953
+ hidden_states,
954
+ temb,
955
+ sample,
956
+ conv_cache.get("mid_block"),
957
+ )
958
+
959
+ # 2. Up
960
+ for i, up_block in enumerate(self.up_blocks):
961
+ conv_cache_key = f"up_block_{i}"
962
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
963
+ create_custom_forward(up_block),
964
+ hidden_states,
965
+ temb,
966
+ sample,
967
+ conv_cache.get(conv_cache_key),
968
+ )
969
+ else:
970
+ # 1. Mid
971
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
972
+ hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
973
+ )
974
+
975
+ # 2. Up
976
+ for i, up_block in enumerate(self.up_blocks):
977
+ conv_cache_key = f"up_block_{i}"
978
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
979
+ hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
980
+ )
981
+
982
+ # 3. Post-process
983
+ hidden_states, new_conv_cache["norm_out"] = self.norm_out(
984
+ hidden_states, sample, conv_cache=conv_cache.get("norm_out")
985
+ )
986
+ hidden_states = self.conv_act(hidden_states)
987
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
988
+
989
+ return hidden_states, new_conv_cache
990
+
991
+
992
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
993
+ r"""
994
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
995
+ [CogVideoX](https://github.com/THUDM/CogVideo).
996
+
997
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
998
+ for all models (such as downloading or saving).
999
+
1000
+ Parameters:
1001
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1002
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1003
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1004
+ Tuple of downsample block types.
1005
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1006
+ Tuple of upsample block types.
1007
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1008
+ Tuple of block output channels.
1009
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1010
+ sample_height (`int`, *optional*, defaults to `480`): Sample input height.
+ sample_width (`int`, *optional*, defaults to `720`): Sample input width.
1011
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1012
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1013
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1014
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1015
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1016
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1017
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1018
+ force_upcast (`bool`, *optional*, defaults to `True`):
1019
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1020
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
1021
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1022
+ """
1023
+
1024
+ _supports_gradient_checkpointing = True
1025
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1026
+
1027
+ @register_to_config
1028
+ def __init__(
1029
+ self,
1030
+ in_channels: int = 3,
1031
+ out_channels: int = 3,
1032
+ down_block_types: Tuple[str] = (
1033
+ "CogVideoXDownBlock3D",
1034
+ "CogVideoXDownBlock3D",
1035
+ "CogVideoXDownBlock3D",
1036
+ "CogVideoXDownBlock3D",
1037
+ ),
1038
+ up_block_types: Tuple[str] = (
1039
+ "CogVideoXUpBlock3D",
1040
+ "CogVideoXUpBlock3D",
1041
+ "CogVideoXUpBlock3D",
1042
+ "CogVideoXUpBlock3D",
1043
+ ),
1044
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
1045
+ latent_channels: int = 16,
1046
+ layers_per_block: int = 3,
1047
+ act_fn: str = "silu",
1048
+ norm_eps: float = 1e-6,
1049
+ norm_num_groups: int = 32,
1050
+ temporal_compression_ratio: float = 4,
1051
+ sample_height: int = 480,
1052
+ sample_width: int = 720,
1053
+ scaling_factor: float = 1.15258426,
1054
+ shift_factor: Optional[float] = None,
1055
+ latents_mean: Optional[Tuple[float]] = None,
1056
+ latents_std: Optional[Tuple[float]] = None,
1057
+ force_upcast: float = True,
1058
+ use_quant_conv: bool = False,
1059
+ use_post_quant_conv: bool = False,
1060
+ invert_scale_latents: bool = False,
1061
+ ):
1062
+ super().__init__()
1063
+
1064
+ self.encoder = CogVideoXEncoder3D(
1065
+ in_channels=in_channels,
1066
+ out_channels=latent_channels,
1067
+ down_block_types=down_block_types,
1068
+ block_out_channels=block_out_channels,
1069
+ layers_per_block=layers_per_block,
1070
+ act_fn=act_fn,
1071
+ norm_eps=norm_eps,
1072
+ norm_num_groups=norm_num_groups,
1073
+ temporal_compression_ratio=temporal_compression_ratio,
1074
+ )
1075
+ self.decoder = CogVideoXDecoder3D(
1076
+ in_channels=latent_channels,
1077
+ out_channels=out_channels,
1078
+ up_block_types=up_block_types,
1079
+ block_out_channels=block_out_channels,
1080
+ layers_per_block=layers_per_block,
1081
+ act_fn=act_fn,
1082
+ norm_eps=norm_eps,
1083
+ norm_num_groups=norm_num_groups,
1084
+ temporal_compression_ratio=temporal_compression_ratio,
1085
+ )
1086
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1087
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1088
+
1089
+ self.use_slicing = False
1090
+ self.use_tiling = False
1091
+
1092
+ # Can be increased to decode more latent frames at once, but this comes at an additional memory cost and is not
1093
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
1094
+ # If you decode X latent frames together, the number of output frames is:
1095
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1096
+ #
1097
+ # Example with num_latent_frames_batch_size = 2:
1098
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1099
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1100
+ # => 6 * 8 = 48 frames
1101
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1102
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1103
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1104
+ # => 1 * 9 + 5 * 8 = 49 frames
1105
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1106
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1107
+ # number of temporal frames.
1108
+ self.num_latent_frames_batch_size = 2
1109
+ self.num_sample_frames_batch_size = 8
1110
+
1111
+ # We make the minimum sample height and width for tiling half of the generally supported resolution
1112
+ self.tile_sample_min_height = sample_height // 2
1113
+ self.tile_sample_min_width = sample_width // 2
1114
+ self.tile_latent_min_height = int(
1115
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1116
+ )
1117
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1118
+
1119
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1120
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1121
+ # and so the tiling implementation has only been tested on those specific resolutions.
1122
+ self.tile_overlap_factor_height = 1 / 6
1123
+ self.tile_overlap_factor_width = 1 / 5
1124
+
1125
+ def _set_gradient_checkpointing(self, module, value=False):
1126
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1127
+ module.gradient_checkpointing = value
1128
+
1129
+ def enable_tiling(
1130
+ self,
1131
+ tile_sample_min_height: Optional[int] = None,
1132
+ tile_sample_min_width: Optional[int] = None,
1133
+ tile_overlap_factor_height: Optional[float] = None,
1134
+ tile_overlap_factor_width: Optional[float] = None,
1135
+ ) -> None:
1136
+ r"""
1137
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1138
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
1139
+ processing larger images.
1140
+
1141
+ Args:
1142
+ tile_sample_min_height (`int`, *optional*):
1143
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1144
+ tile_sample_min_width (`int`, *optional*):
1145
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1146
+ tile_overlap_factor_height (`float`, *optional*):
1147
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1148
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1149
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1150
+ tile_overlap_factor_width (`float`, *optional*):
1151
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1152
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1153
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1154
+ """
1155
+ self.use_tiling = True
1156
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1157
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1158
+ self.tile_latent_min_height = int(
1159
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1160
+ )
1161
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1162
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1163
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1164
+
1165
+ def disable_tiling(self) -> None:
1166
+ r"""
1167
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1168
+ decoding in one step.
1169
+ """
1170
+ self.use_tiling = False
1171
+
1172
+ def enable_slicing(self) -> None:
1173
+ r"""
1174
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1175
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1176
+ """
1177
+ self.use_slicing = True
1178
+
1179
+ def disable_slicing(self) -> None:
1180
+ r"""
1181
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1182
+ decoding in one step.
1183
+ """
1184
+ self.use_slicing = False
1185
+
1186
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1187
+ batch_size, num_channels, num_frames, height, width = x.shape
1188
+
1189
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1190
+ return self.tiled_encode(x)
1191
+
1192
+ frame_batch_size = self.num_sample_frames_batch_size
1193
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1194
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
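+ # Worked example (assuming 49 input frames and frame_batch_size = 8): num_batches = 6 and
+ # remaining_frames = 1, so the processed slices are frames [0:9], [9:17], [17:25], [25:33], [33:41],
+ # [41:49]; the first batch absorbs the extra leading frame and every later batch covers exactly 8 frames.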
1195
+ num_batches = max(num_frames // frame_batch_size, 1)
1196
+ conv_cache = None
1197
+ enc = []
1198
+
1199
+ for i in range(num_batches):
1200
+ remaining_frames = num_frames % frame_batch_size
1201
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1202
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1203
+ x_intermediate = x[:, :, start_frame:end_frame]
1204
+ x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
1205
+ if self.quant_conv is not None:
1206
+ x_intermediate = self.quant_conv(x_intermediate)
1207
+ enc.append(x_intermediate)
1208
+
1209
+ enc = torch.cat(enc, dim=2)
1210
+ return enc
1211
+
1212
+ @apply_forward_hook
1213
+ def encode(
1214
+ self, x: torch.Tensor, return_dict: bool = True
1215
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1216
+ """
1217
+ Encode a batch of images into latents.
1218
+
1219
+ Args:
1220
+ x (`torch.Tensor`): Input batch of images.
1221
+ return_dict (`bool`, *optional*, defaults to `True`):
1222
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1223
+
1224
+ Returns:
1225
+ The latent representations of the encoded videos. If `return_dict` is True, a
1226
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1227
+ """
1228
+ if self.use_slicing and x.shape[0] > 1:
1229
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1230
+ h = torch.cat(encoded_slices)
1231
+ else:
1232
+ h = self._encode(x)
1233
+
1234
+ posterior = DiagonalGaussianDistribution(h)
1235
+
1236
+ if not return_dict:
1237
+ return (posterior,)
1238
+ return AutoencoderKLOutput(latent_dist=posterior)
1239
+
1240
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1241
+ batch_size, num_channels, num_frames, height, width = z.shape
1242
+
1243
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1244
+ return self.tiled_decode(z, return_dict=return_dict)
1245
+
1246
+ frame_batch_size = self.num_latent_frames_batch_size
1247
+ num_batches = max(num_frames // frame_batch_size, 1)
1248
+ conv_cache = None
1249
+ dec = []
1250
+
1251
+ for i in range(num_batches):
1252
+ remaining_frames = num_frames % frame_batch_size
1253
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1254
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1255
+ z_intermediate = z[:, :, start_frame:end_frame]
1256
+ if self.post_quant_conv is not None:
1257
+ z_intermediate = self.post_quant_conv(z_intermediate)
1258
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1259
+ dec.append(z_intermediate)
1260
+
1261
+ dec = torch.cat(dec, dim=2)
1262
+
1263
+ if not return_dict:
1264
+ return (dec,)
1265
+
1266
+ return DecoderOutput(sample=dec)
1267
+
1268
+ @apply_forward_hook
1269
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1270
+ """
1271
+ Decode a batch of images.
1272
+
1273
+ Args:
1274
+ z (`torch.Tensor`): Input batch of latent vectors.
1275
+ return_dict (`bool`, *optional*, defaults to `True`):
1276
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1277
+
1278
+ Returns:
1279
+ [`~models.vae.DecoderOutput`] or `tuple`:
1280
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1281
+ returned.
1282
+ """
1283
+ if self.use_slicing and z.shape[0] > 1:
1284
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1285
+ decoded = torch.cat(decoded_slices)
1286
+ else:
1287
+ decoded = self._decode(z).sample
1288
+
1289
+ if not return_dict:
1290
+ return (decoded,)
1291
+ return DecoderOutput(sample=decoded)
1292
+
1293
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
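+ # Linearly cross-fades the bottom `blend_extent` rows of tile `a` into the top rows of tile `b`
+ # (and, in `blend_h` below, the rightmost columns of `a` into the leftmost columns of `b`) so that
+ # seams between overlapping tiles are hidden.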
1294
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1295
+ for y in range(blend_extent):
1296
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1297
+ y / blend_extent
1298
+ )
1299
+ return b
1300
+
1301
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1302
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1303
+ for x in range(blend_extent):
1304
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1305
+ x / blend_extent
1306
+ )
1307
+ return b
1308
+
1309
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1310
+ r"""Encode a batch of images using a tiled encoder.
1311
+
1312
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1313
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1314
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1315
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1316
+ output, but they should be much less noticeable.
1317
+
1318
+ Args:
1319
+ x (`torch.Tensor`): Input batch of videos.
1320
+
1321
+ Returns:
1322
+ `torch.Tensor`:
1323
+ The latent representation of the encoded videos.
1324
+ """
1325
+ # For a rough memory estimate, take a look at the `tiled_decode` method.
1326
+ batch_size, num_channels, num_frames, height, width = x.shape
1327
+
1328
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1329
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1330
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1331
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1332
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
1333
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
1334
+ frame_batch_size = self.num_sample_frames_batch_size
1335
+
1336
+ # Split x into overlapping tiles and encode them separately.
1337
+ # The tiles have an overlap to avoid seams between tiles.
1338
+ rows = []
1339
+ for i in range(0, height, overlap_height):
1340
+ row = []
1341
+ for j in range(0, width, overlap_width):
1342
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1343
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1344
+ num_batches = max(num_frames // frame_batch_size, 1)
1345
+ conv_cache = None
1346
+ time = []
1347
+
1348
+ for k in range(num_batches):
1349
+ remaining_frames = num_frames % frame_batch_size
1350
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1351
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1352
+ tile = x[
1353
+ :,
1354
+ :,
1355
+ start_frame:end_frame,
1356
+ i : i + self.tile_sample_min_height,
1357
+ j : j + self.tile_sample_min_width,
1358
+ ]
1359
+ tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
1360
+ if self.quant_conv is not None:
1361
+ tile = self.quant_conv(tile)
1362
+ time.append(tile)
1363
+
1364
+ row.append(torch.cat(time, dim=2))
1365
+ rows.append(row)
1366
+
1367
+ result_rows = []
1368
+ for i, row in enumerate(rows):
1369
+ result_row = []
1370
+ for j, tile in enumerate(row):
1371
+ # blend the above tile and the left tile
1372
+ # to the current tile and add the current tile to the result row
1373
+ if i > 0:
1374
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1375
+ if j > 0:
1376
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1377
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1378
+ result_rows.append(torch.cat(result_row, dim=4))
1379
+
1380
+ enc = torch.cat(result_rows, dim=3)
1381
+ return enc
1382
+
1383
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1384
+ r"""
1385
+ Decode a batch of images using a tiled decoder.
1386
+
1387
+ Args:
1388
+ z (`torch.Tensor`): Input batch of latent vectors.
1389
+ return_dict (`bool`, *optional*, defaults to `True`):
1390
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1391
+
1392
+ Returns:
1393
+ [`~models.vae.DecoderOutput`] or `tuple`:
1394
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1395
+ returned.
1396
+ """
1397
+ # Rough memory assessment:
1398
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1399
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1400
+ # - Assume fp16 (2 bytes per value).
1401
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1402
+ #
1403
+ # Memory assessment when using tiling:
1404
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1405
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1406
+
1407
+ batch_size, num_channels, num_frames, height, width = z.shape
1408
+
1409
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1410
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1411
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1412
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1413
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1414
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1415
+ frame_batch_size = self.num_latent_frames_batch_size
1416
+
1417
+ # Split z into overlapping tiles and decode them separately.
1418
+ # The tiles have an overlap to avoid seams between tiles.
1419
+ rows = []
1420
+ for i in range(0, height, overlap_height):
1421
+ row = []
1422
+ for j in range(0, width, overlap_width):
1423
+ num_batches = max(num_frames // frame_batch_size, 1)
1424
+ conv_cache = None
1425
+ time = []
1426
+
1427
+ for k in range(num_batches):
1428
+ remaining_frames = num_frames % frame_batch_size
1429
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1430
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1431
+ tile = z[
1432
+ :,
1433
+ :,
1434
+ start_frame:end_frame,
1435
+ i : i + self.tile_latent_min_height,
1436
+ j : j + self.tile_latent_min_width,
1437
+ ]
1438
+ if self.post_quant_conv is not None:
1439
+ tile = self.post_quant_conv(tile)
1440
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1441
+ time.append(tile)
1442
+
1443
+ row.append(torch.cat(time, dim=2))
1444
+ rows.append(row)
1445
+
1446
+ result_rows = []
1447
+ for i, row in enumerate(rows):
1448
+ result_row = []
1449
+ for j, tile in enumerate(row):
1450
+ # blend the above tile and the left tile
1451
+ # to the current tile and add the current tile to the result row
1452
+ if i > 0:
1453
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1454
+ if j > 0:
1455
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1456
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1457
+ result_rows.append(torch.cat(result_row, dim=4))
1458
+
1459
+ dec = torch.cat(result_rows, dim=3)
1460
+
1461
+ if not return_dict:
1462
+ return (dec,)
1463
+
1464
+ return DecoderOutput(sample=dec)
1465
+
1466
+ def forward(
1467
+ self,
1468
+ sample: torch.Tensor,
1469
+ sample_posterior: bool = False,
1470
+ return_dict: bool = True,
1471
+ generator: Optional[torch.Generator] = None,
1472
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
1473
+ x = sample
1474
+ posterior = self.encode(x).latent_dist
1475
+ if sample_posterior:
1476
+ z = posterior.sample(generator=generator)
1477
+ else:
1478
+ z = posterior.mode()
1479
+ dec = self.decode(z).sample
1480
+ if not return_dict:
1481
+ return (dec,)
1482
+ return DecoderOutput(sample=dec)
icedit/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py ADDED
@@ -0,0 +1,1176 @@
1
+ # Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...utils import is_torch_version, logging
25
+ from ...utils.accelerate_utils import apply_forward_hook
26
+ from ..activations import get_activation
27
+ from ..attention_processor import Attention
28
+ from ..modeling_outputs import AutoencoderKLOutput
29
+ from ..modeling_utils import ModelMixin
30
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ def prepare_causal_attention_mask(
37
+ num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
38
+ ) -> torch.Tensor:
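+ # Builds a block-causal mask over the flattened (frame, spatial) sequence: every position in frame
+ # `i_frame` may attend to all positions in its own frame and in earlier frames, while positions in
+ # later frames remain masked with -inf.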
39
+ seq_len = num_frames * height_width
40
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
41
+ for i in range(seq_len):
42
+ i_frame = i // height_width
43
+ mask[i, : (i_frame + 1) * height_width] = 0
44
+ if batch_size is not None:
45
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
46
+ return mask
47
+
48
+
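For intuition, prepare_causal_attention_mask builds a block-lower-triangular mask: every token of frame t may attend to all tokens of frames 0..t and to nothing later. A minimal standalone sketch of the same pattern (illustrative only, not part of this diff):

    import torch

    num_frames, height_width = 3, 2                     # 3 latent frames, 2 spatial tokens per frame
    frame_idx = torch.arange(num_frames * height_width) // height_width
    allowed = frame_idx[None, :] <= frame_idx[:, None]  # token i may attend to token j iff j's frame <= i's frame
    mask = torch.full(allowed.shape, float("-inf"))
    mask[allowed] = 0.0
    # Row 0 masks frames 1-2 with -inf while the last row attends everywhere,
    # matching prepare_causal_attention_mask(3, 2, torch.float32, torch.device("cpu")).
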
49
+ class HunyuanVideoCausalConv3d(nn.Module):
50
+ def __init__(
51
+ self,
52
+ in_channels: int,
53
+ out_channels: int,
54
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
55
+ stride: Union[int, Tuple[int, int, int]] = 1,
56
+ padding: Union[int, Tuple[int, int, int]] = 0,
57
+ dilation: Union[int, Tuple[int, int, int]] = 1,
58
+ bias: bool = True,
59
+ pad_mode: str = "replicate",
60
+ ) -> None:
61
+ super().__init__()
62
+
63
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
64
+
65
+ self.pad_mode = pad_mode
66
+ self.time_causal_padding = (
67
+ kernel_size[0] // 2,
68
+ kernel_size[0] // 2,
69
+ kernel_size[1] // 2,
70
+ kernel_size[1] // 2,
71
+ kernel_size[2] - 1,
72
+ 0,
73
+ )
74
+
75
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
76
+
77
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
78
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
79
+ return self.conv(hidden_states)
80
+
81
+
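Because the temporal padding above is applied only on the left (kernel_size - 1 replicate-padded frames), output frame t never depends on input frames later than t, and the temporal length is preserved. A quick sketch that checks this property using the class defined above (illustrative only):

    import torch

    conv = HunyuanVideoCausalConv3d(in_channels=4, out_channels=4, kernel_size=3)
    x = torch.randn(1, 4, 5, 8, 8)                      # (batch, channels, frames, height, width)
    y_full = conv(x)
    x_edit = x.clone()
    x_edit[:, :, 3:] = 0.0                              # perturb only "future" frames (t >= 3)
    y_edit = conv(x_edit)
    assert y_full.shape == (1, 4, 5, 8, 8)              # temporal and spatial sizes preserved
    assert torch.allclose(y_full[:, :, :3], y_edit[:, :, :3])  # outputs for t < 3 are unaffected
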
82
+ class HunyuanVideoUpsampleCausal3D(nn.Module):
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: Optional[int] = None,
87
+ kernel_size: int = 3,
88
+ stride: int = 1,
89
+ bias: bool = True,
90
+ upsample_factor: Tuple[float, float, float] = (2, 2, 2),
91
+ ) -> None:
92
+ super().__init__()
93
+
94
+ out_channels = out_channels or in_channels
95
+ self.upsample_factor = upsample_factor
96
+
97
+ self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias)
98
+
99
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
100
+ num_frames = hidden_states.size(2)
101
+
102
+ first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
103
+ first_frame = F.interpolate(
104
+ first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
105
+ ).unsqueeze(2)
106
+
107
+ if num_frames > 1:
108
+ # See: https://github.com/pytorch/pytorch/issues/81665
109
+ # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
110
+ # is fixed, this will raise either a runtime error, or fail silently with bad outputs.
111
+ # If you are encountering an error here, make sure to try running encoding/decoding with
112
+ # `vae.enable_tiling()` first. If that doesn't work, open an issue at:
113
+ # https://github.com/huggingface/diffusers/issues
114
+ other_frames = other_frames.contiguous()
115
+ other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
116
+ hidden_states = torch.cat((first_frame, other_frames), dim=2)
117
+ else:
118
+ hidden_states = first_frame
119
+
120
+ hidden_states = self.conv(hidden_states)
121
+ return hidden_states
122
+
123
+
124
+ class HunyuanVideoDownsampleCausal3D(nn.Module):
125
+ def __init__(
126
+ self,
127
+ channels: int,
128
+ out_channels: Optional[int] = None,
129
+ padding: int = 1,
130
+ kernel_size: int = 3,
131
+ bias: bool = True,
132
+ stride=2,
133
+ ) -> None:
134
+ super().__init__()
135
+ out_channels = out_channels or channels
136
+
137
+ self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias)
138
+
139
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
140
+ hidden_states = self.conv(hidden_states)
141
+ return hidden_states
142
+
143
+
144
+ class HunyuanVideoResnetBlockCausal3D(nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ out_channels: Optional[int] = None,
149
+ dropout: float = 0.0,
150
+ groups: int = 32,
151
+ eps: float = 1e-6,
152
+ non_linearity: str = "swish",
153
+ ) -> None:
154
+ super().__init__()
155
+ out_channels = out_channels or in_channels
156
+
157
+ self.nonlinearity = get_activation(non_linearity)
158
+
159
+ self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
160
+ self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
161
+
162
+ self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
163
+ self.dropout = nn.Dropout(dropout)
164
+ self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
165
+
166
+ self.conv_shortcut = None
167
+ if in_channels != out_channels:
168
+ self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
169
+
170
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
171
+ hidden_states = hidden_states.contiguous()
172
+ residual = hidden_states
173
+
174
+ hidden_states = self.norm1(hidden_states)
175
+ hidden_states = self.nonlinearity(hidden_states)
176
+ hidden_states = self.conv1(hidden_states)
177
+
178
+ hidden_states = self.norm2(hidden_states)
179
+ hidden_states = self.nonlinearity(hidden_states)
180
+ hidden_states = self.dropout(hidden_states)
181
+ hidden_states = self.conv2(hidden_states)
182
+
183
+ if self.conv_shortcut is not None:
184
+ residual = self.conv_shortcut(residual)
185
+
186
+ hidden_states = hidden_states + residual
187
+ return hidden_states
188
+
189
+
190
+ class HunyuanVideoMidBlock3D(nn.Module):
191
+ def __init__(
192
+ self,
193
+ in_channels: int,
194
+ dropout: float = 0.0,
195
+ num_layers: int = 1,
196
+ resnet_eps: float = 1e-6,
197
+ resnet_act_fn: str = "swish",
198
+ resnet_groups: int = 32,
199
+ add_attention: bool = True,
200
+ attention_head_dim: int = 1,
201
+ ) -> None:
202
+ super().__init__()
203
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
204
+ self.add_attention = add_attention
205
+
206
+ # There is always at least one resnet
207
+ resnets = [
208
+ HunyuanVideoResnetBlockCausal3D(
209
+ in_channels=in_channels,
210
+ out_channels=in_channels,
211
+ eps=resnet_eps,
212
+ groups=resnet_groups,
213
+ dropout=dropout,
214
+ non_linearity=resnet_act_fn,
215
+ )
216
+ ]
217
+ attentions = []
218
+
219
+ for _ in range(num_layers):
220
+ if self.add_attention:
221
+ attentions.append(
222
+ Attention(
223
+ in_channels,
224
+ heads=in_channels // attention_head_dim,
225
+ dim_head=attention_head_dim,
226
+ eps=resnet_eps,
227
+ norm_num_groups=resnet_groups,
228
+ residual_connection=True,
229
+ bias=True,
230
+ upcast_softmax=True,
231
+ _from_deprecated_attn_block=True,
232
+ )
233
+ )
234
+ else:
235
+ attentions.append(None)
236
+
237
+ resnets.append(
238
+ HunyuanVideoResnetBlockCausal3D(
239
+ in_channels=in_channels,
240
+ out_channels=in_channels,
241
+ eps=resnet_eps,
242
+ groups=resnet_groups,
243
+ dropout=dropout,
244
+ non_linearity=resnet_act_fn,
245
+ )
246
+ )
247
+
248
+ self.attentions = nn.ModuleList(attentions)
249
+ self.resnets = nn.ModuleList(resnets)
250
+
251
+ self.gradient_checkpointing = False
252
+
253
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
255
+
256
+ def create_custom_forward(module, return_dict=None):
257
+ def custom_forward(*inputs):
258
+ if return_dict is not None:
259
+ return module(*inputs, return_dict=return_dict)
260
+ else:
261
+ return module(*inputs)
262
+
263
+ return custom_forward
264
+
265
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
266
+
267
+ hidden_states = torch.utils.checkpoint.checkpoint(
268
+ create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
269
+ )
270
+
271
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
272
+ if attn is not None:
273
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
274
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
275
+ attention_mask = prepare_causal_attention_mask(
276
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
277
+ )
278
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
279
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
280
+
281
+ hidden_states = torch.utils.checkpoint.checkpoint(
282
+ create_custom_forward(resnet), hidden_states, **ckpt_kwargs
283
+ )
284
+
285
+ else:
286
+ hidden_states = self.resnets[0](hidden_states)
287
+
288
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
289
+ if attn is not None:
290
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
291
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
292
+ attention_mask = prepare_causal_attention_mask(
293
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
294
+ )
295
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
296
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
297
+
298
+ hidden_states = resnet(hidden_states)
299
+
300
+ return hidden_states
301
+
302
+
303
+ class HunyuanVideoDownBlock3D(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channels: int,
307
+ out_channels: int,
308
+ dropout: float = 0.0,
309
+ num_layers: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_act_fn: str = "swish",
312
+ resnet_groups: int = 32,
313
+ add_downsample: bool = True,
314
+ downsample_stride: int = 2,
315
+ downsample_padding: int = 1,
316
+ ) -> None:
317
+ super().__init__()
318
+ resnets = []
319
+
320
+ for i in range(num_layers):
321
+ in_channels = in_channels if i == 0 else out_channels
322
+ resnets.append(
323
+ HunyuanVideoResnetBlockCausal3D(
324
+ in_channels=in_channels,
325
+ out_channels=out_channels,
326
+ eps=resnet_eps,
327
+ groups=resnet_groups,
328
+ dropout=dropout,
329
+ non_linearity=resnet_act_fn,
330
+ )
331
+ )
332
+
333
+ self.resnets = nn.ModuleList(resnets)
334
+
335
+ if add_downsample:
336
+ self.downsamplers = nn.ModuleList(
337
+ [
338
+ HunyuanVideoDownsampleCausal3D(
339
+ out_channels,
340
+ out_channels=out_channels,
341
+ padding=downsample_padding,
342
+ stride=downsample_stride,
343
+ )
344
+ ]
345
+ )
346
+ else:
347
+ self.downsamplers = None
348
+
349
+ self.gradient_checkpointing = False
350
+
351
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
352
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
353
+
354
+ def create_custom_forward(module, return_dict=None):
355
+ def custom_forward(*inputs):
356
+ if return_dict is not None:
357
+ return module(*inputs, return_dict=return_dict)
358
+ else:
359
+ return module(*inputs)
360
+
361
+ return custom_forward
362
+
363
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
364
+
365
+ for resnet in self.resnets:
366
+ hidden_states = torch.utils.checkpoint.checkpoint(
367
+ create_custom_forward(resnet), hidden_states, **ckpt_kwargs
368
+ )
369
+ else:
370
+ for resnet in self.resnets:
371
+ hidden_states = resnet(hidden_states)
372
+
373
+ if self.downsamplers is not None:
374
+ for downsampler in self.downsamplers:
375
+ hidden_states = downsampler(hidden_states)
376
+
377
+ return hidden_states
378
+
379
+
380
+ class HunyuanVideoUpBlock3D(nn.Module):
381
+ def __init__(
382
+ self,
383
+ in_channels: int,
384
+ out_channels: int,
385
+ dropout: float = 0.0,
386
+ num_layers: int = 1,
387
+ resnet_eps: float = 1e-6,
388
+ resnet_act_fn: str = "swish",
389
+ resnet_groups: int = 32,
390
+ add_upsample: bool = True,
391
+ upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
392
+ ) -> None:
393
+ super().__init__()
394
+ resnets = []
395
+
396
+ for i in range(num_layers):
397
+ input_channels = in_channels if i == 0 else out_channels
398
+
399
+ resnets.append(
400
+ HunyuanVideoResnetBlockCausal3D(
401
+ in_channels=input_channels,
402
+ out_channels=out_channels,
403
+ eps=resnet_eps,
404
+ groups=resnet_groups,
405
+ dropout=dropout,
406
+ non_linearity=resnet_act_fn,
407
+ )
408
+ )
409
+
410
+ self.resnets = nn.ModuleList(resnets)
411
+
412
+ if add_upsample:
413
+ self.upsamplers = nn.ModuleList(
414
+ [
415
+ HunyuanVideoUpsampleCausal3D(
416
+ out_channels,
417
+ out_channels=out_channels,
418
+ upsample_factor=upsample_scale_factor,
419
+ )
420
+ ]
421
+ )
422
+ else:
423
+ self.upsamplers = None
424
+
425
+ self.gradient_checkpointing = False
426
+
427
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
428
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
429
+
430
+ def create_custom_forward(module, return_dict=None):
431
+ def custom_forward(*inputs):
432
+ if return_dict is not None:
433
+ return module(*inputs, return_dict=return_dict)
434
+ else:
435
+ return module(*inputs)
436
+
437
+ return custom_forward
438
+
439
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
440
+
441
+ for resnet in self.resnets:
442
+ hidden_states = torch.utils.checkpoint.checkpoint(
443
+ create_custom_forward(resnet), hidden_states, **ckpt_kwargs
444
+ )
445
+
446
+ else:
447
+ for resnet in self.resnets:
448
+ hidden_states = resnet(hidden_states)
449
+
450
+ if self.upsamplers is not None:
451
+ for upsampler in self.upsamplers:
452
+ hidden_states = upsampler(hidden_states)
453
+
454
+ return hidden_states
455
+
456
+
457
+ class HunyuanVideoEncoder3D(nn.Module):
458
+ r"""
459
+ Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
460
+ """
461
+
462
+ def __init__(
463
+ self,
464
+ in_channels: int = 3,
465
+ out_channels: int = 3,
466
+ down_block_types: Tuple[str, ...] = (
467
+ "HunyuanVideoDownBlock3D",
468
+ "HunyuanVideoDownBlock3D",
469
+ "HunyuanVideoDownBlock3D",
470
+ "HunyuanVideoDownBlock3D",
471
+ ),
472
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
473
+ layers_per_block: int = 2,
474
+ norm_num_groups: int = 32,
475
+ act_fn: str = "silu",
476
+ double_z: bool = True,
477
+ mid_block_add_attention=True,
478
+ temporal_compression_ratio: int = 4,
479
+ spatial_compression_ratio: int = 8,
480
+ ) -> None:
481
+ super().__init__()
482
+
483
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
484
+ self.mid_block = None
485
+ self.down_blocks = nn.ModuleList([])
486
+
487
+ output_channel = block_out_channels[0]
488
+ for i, down_block_type in enumerate(down_block_types):
489
+ if down_block_type != "HunyuanVideoDownBlock3D":
490
+ raise ValueError(f"Unsupported down_block_type: {down_block_type}")
491
+
492
+ input_channel = output_channel
493
+ output_channel = block_out_channels[i]
494
+ is_final_block = i == len(block_out_channels) - 1
495
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
496
+ num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
497
+
498
+ if temporal_compression_ratio == 4:
499
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
500
+ add_time_downsample = bool(
501
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
502
+ )
503
+ elif temporal_compression_ratio == 8:
504
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
505
+ add_time_downsample = bool(i < num_time_downsample_layers)
506
+ else:
507
+ raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}")
508
+
509
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
510
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
511
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
512
+
513
+ down_block = HunyuanVideoDownBlock3D(
514
+ num_layers=layers_per_block,
515
+ in_channels=input_channel,
516
+ out_channels=output_channel,
517
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
518
+ resnet_eps=1e-6,
519
+ resnet_act_fn=act_fn,
520
+ resnet_groups=norm_num_groups,
521
+ downsample_stride=downsample_stride,
522
+ downsample_padding=0,
523
+ )
524
+
525
+ self.down_blocks.append(down_block)
526
+
527
+ self.mid_block = HunyuanVideoMidBlock3D(
528
+ in_channels=block_out_channels[-1],
529
+ resnet_eps=1e-6,
530
+ resnet_act_fn=act_fn,
531
+ attention_head_dim=block_out_channels[-1],
532
+ resnet_groups=norm_num_groups,
533
+ add_attention=mid_block_add_attention,
534
+ )
535
+
536
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
537
+ self.conv_act = nn.SiLU()
538
+
539
+ conv_out_channels = 2 * out_channels if double_z else out_channels
540
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
541
+
542
+ self.gradient_checkpointing = False
543
+
544
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
545
+ hidden_states = self.conv_in(hidden_states)
546
+
547
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
548
+
549
+ def create_custom_forward(module, return_dict=None):
550
+ def custom_forward(*inputs):
551
+ if return_dict is not None:
552
+ return module(*inputs, return_dict=return_dict)
553
+ else:
554
+ return module(*inputs)
555
+
556
+ return custom_forward
557
+
558
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
559
+
560
+ for down_block in self.down_blocks:
561
+ hidden_states = torch.utils.checkpoint.checkpoint(
562
+ create_custom_forward(down_block), hidden_states, **ckpt_kwargs
563
+ )
564
+
565
+ hidden_states = torch.utils.checkpoint.checkpoint(
566
+ create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
567
+ )
568
+ else:
569
+ for down_block in self.down_blocks:
570
+ hidden_states = down_block(hidden_states)
571
+
572
+ hidden_states = self.mid_block(hidden_states)
573
+
574
+ hidden_states = self.conv_norm_out(hidden_states)
575
+ hidden_states = self.conv_act(hidden_states)
576
+ hidden_states = self.conv_out(hidden_states)
577
+
578
+ return hidden_states
579
+
580
+
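With the encoder defaults above (four blocks, spatial_compression_ratio=8, temporal_compression_ratio=4), the schedule yields three 2x spatial downsamples and two 2x temporal downsamples. A small sketch that reproduces the per-block decision (illustrative only; it mirrors the loop in __init__):

    import numpy as np

    block_out_channels = (128, 256, 512, 512)
    spatial_compression_ratio, temporal_compression_ratio = 8, 4
    num_spatial = int(np.log2(spatial_compression_ratio))  # 3
    num_time = int(np.log2(temporal_compression_ratio))    # 2
    for i in range(len(block_out_channels)):
        is_final = i == len(block_out_channels) - 1
        spatial = i < num_spatial
        time = i >= len(block_out_channels) - 1 - num_time and not is_final
        print(i, "T-stride:", 2 if time else 1, "HW-stride:", 2 if spatial else 1)
    # block 0: spatial only; blocks 1-2: spatial and temporal; block 3: no downsampling
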
581
+ class HunyuanVideoDecoder3D(nn.Module):
582
+ r"""
583
+ Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
584
+ """
585
+
586
+ def __init__(
587
+ self,
588
+ in_channels: int = 3,
589
+ out_channels: int = 3,
590
+ up_block_types: Tuple[str, ...] = (
591
+ "HunyuanVideoUpBlock3D",
592
+ "HunyuanVideoUpBlock3D",
593
+ "HunyuanVideoUpBlock3D",
594
+ "HunyuanVideoUpBlock3D",
595
+ ),
596
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
597
+ layers_per_block: int = 2,
598
+ norm_num_groups: int = 32,
599
+ act_fn: str = "silu",
600
+ mid_block_add_attention=True,
601
+ time_compression_ratio: int = 4,
602
+ spatial_compression_ratio: int = 8,
603
+ ):
604
+ super().__init__()
605
+ self.layers_per_block = layers_per_block
606
+
607
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
608
+ self.up_blocks = nn.ModuleList([])
609
+
610
+ # mid
611
+ self.mid_block = HunyuanVideoMidBlock3D(
612
+ in_channels=block_out_channels[-1],
613
+ resnet_eps=1e-6,
614
+ resnet_act_fn=act_fn,
615
+ attention_head_dim=block_out_channels[-1],
616
+ resnet_groups=norm_num_groups,
617
+ add_attention=mid_block_add_attention,
618
+ )
619
+
620
+ # up
621
+ reversed_block_out_channels = list(reversed(block_out_channels))
622
+ output_channel = reversed_block_out_channels[0]
623
+ for i, up_block_type in enumerate(up_block_types):
624
+ if up_block_type != "HunyuanVideoUpBlock3D":
625
+ raise ValueError(f"Unsupported up_block_type: {up_block_type}")
626
+
627
+ prev_output_channel = output_channel
628
+ output_channel = reversed_block_out_channels[i]
629
+ is_final_block = i == len(block_out_channels) - 1
630
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
631
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
632
+
633
+ if time_compression_ratio == 4:
634
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
635
+ add_time_upsample = bool(
636
+ i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
637
+ )
638
+ else:
639
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
640
+
641
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
642
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
643
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
644
+
645
+ up_block = HunyuanVideoUpBlock3D(
646
+ num_layers=self.layers_per_block + 1,
647
+ in_channels=prev_output_channel,
648
+ out_channels=output_channel,
649
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
650
+ upsample_scale_factor=upsample_scale_factor,
651
+ resnet_eps=1e-6,
652
+ resnet_act_fn=act_fn,
653
+ resnet_groups=norm_num_groups,
654
+ )
655
+
656
+ self.up_blocks.append(up_block)
657
+ prev_output_channel = output_channel
658
+
659
+ # out
660
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
661
+ self.conv_act = nn.SiLU()
662
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
663
+
664
+ self.gradient_checkpointing = False
665
+
666
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
667
+ hidden_states = self.conv_in(hidden_states)
668
+
669
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
670
+
671
+ def create_custom_forward(module, return_dict=None):
672
+ def custom_forward(*inputs):
673
+ if return_dict is not None:
674
+ return module(*inputs, return_dict=return_dict)
675
+ else:
676
+ return module(*inputs)
677
+
678
+ return custom_forward
679
+
680
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
681
+
682
+ hidden_states = torch.utils.checkpoint.checkpoint(
683
+ create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
684
+ )
685
+
686
+ for up_block in self.up_blocks:
687
+ hidden_states = torch.utils.checkpoint.checkpoint(
688
+ create_custom_forward(up_block), hidden_states, **ckpt_kwargs
689
+ )
690
+ else:
691
+ hidden_states = self.mid_block(hidden_states)
692
+
693
+ for up_block in self.up_blocks:
694
+ hidden_states = up_block(hidden_states)
695
+
696
+ # post-process
697
+ hidden_states = self.conv_norm_out(hidden_states)
698
+ hidden_states = self.conv_act(hidden_states)
699
+ hidden_states = self.conv_out(hidden_states)
700
+
701
+ return hidden_states
702
+
703
+
704
+ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
705
+ r"""
706
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
707
+ Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
708
+
709
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
710
+ for all models (such as downloading or saving).
711
+ """
712
+
713
+ _supports_gradient_checkpointing = True
714
+
715
+ @register_to_config
716
+ def __init__(
717
+ self,
718
+ in_channels: int = 3,
719
+ out_channels: int = 3,
720
+ latent_channels: int = 16,
721
+ down_block_types: Tuple[str, ...] = (
722
+ "HunyuanVideoDownBlock3D",
723
+ "HunyuanVideoDownBlock3D",
724
+ "HunyuanVideoDownBlock3D",
725
+ "HunyuanVideoDownBlock3D",
726
+ ),
727
+ up_block_types: Tuple[str, ...] = (
728
+ "HunyuanVideoUpBlock3D",
729
+ "HunyuanVideoUpBlock3D",
730
+ "HunyuanVideoUpBlock3D",
731
+ "HunyuanVideoUpBlock3D",
732
+ ),
733
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
734
+ layers_per_block: int = 2,
735
+ act_fn: str = "silu",
736
+ norm_num_groups: int = 32,
737
+ scaling_factor: float = 0.476986,
738
+ spatial_compression_ratio: int = 8,
739
+ temporal_compression_ratio: int = 4,
740
+ mid_block_add_attention: bool = True,
741
+ ) -> None:
742
+ super().__init__()
743
+
744
+ self.time_compression_ratio = temporal_compression_ratio
745
+
746
+ self.encoder = HunyuanVideoEncoder3D(
747
+ in_channels=in_channels,
748
+ out_channels=latent_channels,
749
+ down_block_types=down_block_types,
750
+ block_out_channels=block_out_channels,
751
+ layers_per_block=layers_per_block,
752
+ norm_num_groups=norm_num_groups,
753
+ act_fn=act_fn,
754
+ double_z=True,
755
+ mid_block_add_attention=mid_block_add_attention,
756
+ temporal_compression_ratio=temporal_compression_ratio,
757
+ spatial_compression_ratio=spatial_compression_ratio,
758
+ )
759
+
760
+ self.decoder = HunyuanVideoDecoder3D(
761
+ in_channels=latent_channels,
762
+ out_channels=out_channels,
763
+ up_block_types=up_block_types,
764
+ block_out_channels=block_out_channels,
765
+ layers_per_block=layers_per_block,
766
+ norm_num_groups=norm_num_groups,
767
+ act_fn=act_fn,
768
+ time_compression_ratio=temporal_compression_ratio,
769
+ spatial_compression_ratio=spatial_compression_ratio,
770
+ mid_block_add_attention=mid_block_add_attention,
771
+ )
772
+
773
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
774
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
775
+
776
+ self.spatial_compression_ratio = spatial_compression_ratio
777
+ self.temporal_compression_ratio = temporal_compression_ratio
778
+
779
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
780
+ # to perform decoding of a single video latent at a time.
781
+ self.use_slicing = False
782
+
783
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
784
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
785
+ # intermediate tiles together, the memory requirement can be lowered.
786
+ self.use_tiling = False
787
+
788
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
789
+ # in smaller temporal batches (controlled by `tile_sample_min_num_frames` and `tile_sample_stride_num_frames`), the memory requirement can be lowered.
790
+ self.use_framewise_encoding = True
791
+ self.use_framewise_decoding = True
792
+
793
+ # The minimal tile height and width for spatial tiling to be used
794
+ self.tile_sample_min_height = 256
795
+ self.tile_sample_min_width = 256
796
+ self.tile_sample_min_num_frames = 16
797
+
798
+ # The minimal distance between two spatial tiles
799
+ self.tile_sample_stride_height = 192
800
+ self.tile_sample_stride_width = 192
801
+ self.tile_sample_stride_num_frames = 12
802
+
803
+ def _set_gradient_checkpointing(self, module, value=False):
804
+ if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
805
+ module.gradient_checkpointing = value
806
+
807
+ def enable_tiling(
808
+ self,
809
+ tile_sample_min_height: Optional[int] = None,
810
+ tile_sample_min_width: Optional[int] = None,
811
+ tile_sample_min_num_frames: Optional[int] = None,
812
+ tile_sample_stride_height: Optional[float] = None,
813
+ tile_sample_stride_width: Optional[float] = None,
814
+ tile_sample_stride_num_frames: Optional[float] = None,
815
+ ) -> None:
816
+ r"""
817
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
818
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
819
+ processing larger images.
820
+
821
+ Args:
822
+ tile_sample_min_height (`int`, *optional*):
823
+ The minimum height required for a sample to be separated into tiles across the height dimension.
824
+ tile_sample_min_width (`int`, *optional*):
825
+ The minimum width required for a sample to be separated into tiles across the width dimension.
826
+ tile_sample_min_num_frames (`int`, *optional*):
827
+ The minimum number of frames required for a sample to be separated into tiles across the frame
828
+ dimension.
829
+ tile_sample_stride_height (`int`, *optional*):
830
+ The stride between two consecutive vertical tiles. This is to ensure that there are
831
+ no tiling artifacts produced across the height dimension.
832
+ tile_sample_stride_width (`int`, *optional*):
833
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
834
+ artifacts produced across the width dimension.
835
+ tile_sample_stride_num_frames (`int`, *optional*):
836
+ The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
837
+ produced across the frame dimension.
838
+ """
839
+ self.use_tiling = True
840
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
841
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
842
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
843
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
844
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
845
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
846
+
847
+ def disable_tiling(self) -> None:
848
+ r"""
849
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
850
+ decoding in one step.
851
+ """
852
+ self.use_tiling = False
853
+
854
+ def enable_slicing(self) -> None:
855
+ r"""
856
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
857
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
858
+ """
859
+ self.use_slicing = True
860
+
861
+ def disable_slicing(self) -> None:
862
+ r"""
863
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
864
+ decoding in one step.
865
+ """
866
+ self.use_slicing = False
867
+
868
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
869
+ batch_size, num_channels, num_frames, height, width = x.shape
870
+
871
+ if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
872
+ return self._temporal_tiled_encode(x)
873
+
874
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
875
+ return self.tiled_encode(x)
876
+
877
+ x = self.encoder(x)
878
+ enc = self.quant_conv(x)
879
+ return enc
880
+
881
+ @apply_forward_hook
882
+ def encode(
883
+ self, x: torch.Tensor, return_dict: bool = True
884
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
885
+ r"""
886
+ Encode a batch of images into latents.
887
+
888
+ Args:
889
+ x (`torch.Tensor`): Input batch of images.
890
+ return_dict (`bool`, *optional*, defaults to `True`):
891
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
892
+
893
+ Returns:
894
+ The latent representations of the encoded videos. If `return_dict` is True, a
895
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
896
+ """
897
+ if self.use_slicing and x.shape[0] > 1:
898
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
899
+ h = torch.cat(encoded_slices)
900
+ else:
901
+ h = self._encode(x)
902
+
903
+ posterior = DiagonalGaussianDistribution(h)
904
+
905
+ if not return_dict:
906
+ return (posterior,)
907
+ return AutoencoderKLOutput(latent_dist=posterior)
908
+
909
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
910
+ batch_size, num_channels, num_frames, height, width = z.shape
911
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
912
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
913
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
914
+
915
+ if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
916
+ return self._temporal_tiled_decode(z, return_dict=return_dict)
917
+
918
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
919
+ return self.tiled_decode(z, return_dict=return_dict)
920
+
921
+ z = self.post_quant_conv(z)
922
+ dec = self.decoder(z)
923
+
924
+ if not return_dict:
925
+ return (dec,)
926
+
927
+ return DecoderOutput(sample=dec)
928
+
929
+ @apply_forward_hook
930
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
931
+ r"""
932
+ Decode a batch of images.
933
+
934
+ Args:
935
+ z (`torch.Tensor`): Input batch of latent vectors.
936
+ return_dict (`bool`, *optional*, defaults to `True`):
937
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
938
+
939
+ Returns:
940
+ [`~models.vae.DecoderOutput`] or `tuple`:
941
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
942
+ returned.
943
+ """
944
+ if self.use_slicing and z.shape[0] > 1:
945
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
946
+ decoded = torch.cat(decoded_slices)
947
+ else:
948
+ decoded = self._decode(z).sample
949
+
950
+ if not return_dict:
951
+ return (decoded,)
952
+
953
+ return DecoderOutput(sample=decoded)
954
+
955
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
956
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
957
+ for y in range(blend_extent):
958
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
959
+ y / blend_extent
960
+ )
961
+ return b
962
+
963
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
964
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
965
+ for x in range(blend_extent):
966
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
967
+ x / blend_extent
968
+ )
969
+ return b
970
+
971
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
972
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
973
+ for x in range(blend_extent):
974
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
975
+ x / blend_extent
976
+ )
977
+ return b
978
+
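blend_v, blend_h, and blend_t perform the same linear cross-fade over the overlap region between two adjacent tiles; only the axis differs. A one-dimensional sketch of the ramp (illustrative only):

    import torch

    blend_extent = 4
    a = torch.ones(8)    # trailing edge of the previous tile
    b = torch.zeros(8)   # leading edge of the current tile
    for x in range(blend_extent):
        b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
    print(b.tolist())    # [1.0, 0.75, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0] -> the seam fades out smoothly
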
979
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
980
+ r"""Encode a batch of images using a tiled encoder.
981
+
982
+ Args:
983
+ x (`torch.Tensor`): Input batch of videos.
984
+
985
+ Returns:
986
+ `torch.Tensor`:
987
+ The latent representation of the encoded videos.
988
+ """
989
+ batch_size, num_channels, num_frames, height, width = x.shape
990
+ latent_height = height // self.spatial_compression_ratio
991
+ latent_width = width // self.spatial_compression_ratio
992
+
993
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
994
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
995
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
996
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
997
+
998
+ blend_height = tile_latent_min_height - tile_latent_stride_height
999
+ blend_width = tile_latent_min_width - tile_latent_stride_width
1000
+
1001
+ # Split x into overlapping tiles and encode them separately.
1002
+ # The tiles have an overlap to avoid seams between tiles.
1003
+ rows = []
1004
+ for i in range(0, height, self.tile_sample_stride_height):
1005
+ row = []
1006
+ for j in range(0, width, self.tile_sample_stride_width):
1007
+ tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
1008
+ tile = self.encoder(tile)
1009
+ tile = self.quant_conv(tile)
1010
+ row.append(tile)
1011
+ rows.append(row)
1012
+
1013
+ result_rows = []
1014
+ for i, row in enumerate(rows):
1015
+ result_row = []
1016
+ for j, tile in enumerate(row):
1017
+ # blend the above tile and the left tile
1018
+ # to the current tile and add the current tile to the result row
1019
+ if i > 0:
1020
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1021
+ if j > 0:
1022
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1023
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
1024
+ result_rows.append(torch.cat(result_row, dim=4))
1025
+
1026
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1027
+ return enc
1028
+
1029
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1030
+ r"""
1031
+ Decode a batch of images using a tiled decoder.
1032
+
1033
+ Args:
1034
+ z (`torch.Tensor`): Input batch of latent vectors.
1035
+ return_dict (`bool`, *optional*, defaults to `True`):
1036
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1037
+
1038
+ Returns:
1039
+ [`~models.vae.DecoderOutput`] or `tuple`:
1040
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1041
+ returned.
1042
+ """
1043
+
1044
+ batch_size, num_channels, num_frames, height, width = z.shape
1045
+ sample_height = height * self.spatial_compression_ratio
1046
+ sample_width = width * self.spatial_compression_ratio
1047
+
1048
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1049
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1050
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1051
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1052
+
1053
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1054
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1055
+
1056
+ # Split z into overlapping tiles and decode them separately.
1057
+ # The tiles have an overlap to avoid seams between tiles.
1058
+ rows = []
1059
+ for i in range(0, height, tile_latent_stride_height):
1060
+ row = []
1061
+ for j in range(0, width, tile_latent_stride_width):
1062
+ tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1063
+ tile = self.post_quant_conv(tile)
1064
+ decoded = self.decoder(tile)
1065
+ row.append(decoded)
1066
+ rows.append(row)
1067
+
1068
+ result_rows = []
1069
+ for i, row in enumerate(rows):
1070
+ result_row = []
1071
+ for j, tile in enumerate(row):
1072
+ # blend the above tile and the left tile
1073
+ # to the current tile and add the current tile to the result row
1074
+ if i > 0:
1075
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1076
+ if j > 0:
1077
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1078
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1079
+ result_rows.append(torch.cat(result_row, dim=-1))
1080
+
1081
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1082
+
1083
+ if not return_dict:
1084
+ return (dec,)
1085
+ return DecoderOutput(sample=dec)
1086
+
1087
+ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
1088
+ batch_size, num_channels, num_frames, height, width = x.shape
1089
+ latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
1090
+
1091
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1092
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
1093
+ blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
1094
+
1095
+ row = []
1096
+ for i in range(0, num_frames, self.tile_sample_stride_num_frames):
1097
+ tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
1098
+ if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
1099
+ tile = self.tiled_encode(tile)
1100
+ else:
1101
+ tile = self.encoder(tile)
1102
+ tile = self.quant_conv(tile)
1103
+ if i > 0:
1104
+ tile = tile[:, :, 1:, :, :]
1105
+ row.append(tile)
1106
+
1107
+ result_row = []
1108
+ for i, tile in enumerate(row):
1109
+ if i > 0:
1110
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1111
+ result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
1112
+ else:
1113
+ result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
1114
+
1115
+ enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
1116
+ return enc
1117
+
1118
+ def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1119
+ batch_size, num_channels, num_frames, height, width = z.shape
1120
+ num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
1121
+
1122
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1123
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1124
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1125
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
1126
+ blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
1127
+
1128
+ row = []
1129
+ for i in range(0, num_frames, tile_latent_stride_num_frames):
1130
+ tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
1131
+ if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
1132
+ decoded = self.tiled_decode(tile, return_dict=True).sample
1133
+ else:
1134
+ tile = self.post_quant_conv(tile)
1135
+ decoded = self.decoder(tile)
1136
+ if i > 0:
1137
+ decoded = decoded[:, :, 1:, :, :]
1138
+ row.append(decoded)
1139
+
1140
+ result_row = []
1141
+ for i, tile in enumerate(row):
1142
+ if i > 0:
1143
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1144
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :])
1145
+ else:
1146
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
1147
+
1148
+ dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
1149
+
1150
+ if not return_dict:
1151
+ return (dec,)
1152
+ return DecoderOutput(sample=dec)
1153
+
1154
+ def forward(
1155
+ self,
1156
+ sample: torch.Tensor,
1157
+ sample_posterior: bool = False,
1158
+ return_dict: bool = True,
1159
+ generator: Optional[torch.Generator] = None,
1160
+ ) -> Union[DecoderOutput, torch.Tensor]:
1161
+ r"""
1162
+ Args:
1163
+ sample (`torch.Tensor`): Input sample.
1164
+ sample_posterior (`bool`, *optional*, defaults to `False`):
1165
+ Whether to sample from the posterior.
1166
+ return_dict (`bool`, *optional*, defaults to `True`):
1167
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1168
+ """
1169
+ x = sample
1170
+ posterior = self.encode(x).latent_dist
1171
+ if sample_posterior:
1172
+ z = posterior.sample(generator=generator)
1173
+ else:
1174
+ z = posterior.mode()
1175
+ dec = self.decode(z, return_dict=return_dict)
1176
+ return dec
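Taken together, the class supports batch slicing, spatial tiling, and frame-wise temporal tiling so that long or large videos can be encoded and decoded within bounded memory. A minimal usage sketch (the checkpoint id, CUDA device, and shapes are illustrative assumptions, not taken from this commit; it also assumes the vendored diffusers package is on the import path and exports the class):

    import torch
    from diffusers import AutoencoderKLHunyuanVideo

    vae = AutoencoderKLHunyuanVideo.from_pretrained(
        "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")
    vae.enable_tiling()    # spatial tiles blended by tiled_encode / tiled_decode
    vae.enable_slicing()   # process one video of the batch at a time

    video = torch.randn(1, 3, 17, 256, 256, dtype=torch.float16, device="cuda")  # (B, C, T, H, W)
    latents = vae.encode(video).latent_dist.sample()   # (1, 16, 5, 32, 32): 4x temporal, 8x spatial compression
    frames = vae.decode(latents).sample                # back to (1, 3, 17, 256, 256)
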
icedit/diffusers/models/autoencoders/autoencoder_kl_ltx.py ADDED
@@ -0,0 +1,1338 @@
1
+ # Copyright 2024 The Lightricks team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import FromOriginalModelMixin
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..activations import get_activation
25
+ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
26
+ from ..modeling_outputs import AutoencoderKLOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import RMSNorm
29
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
30
+
31
+
32
+ class LTXVideoCausalConv3d(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels: int,
36
+ out_channels: int,
37
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
38
+ stride: Union[int, Tuple[int, int, int]] = 1,
39
+ dilation: Union[int, Tuple[int, int, int]] = 1,
40
+ groups: int = 1,
41
+ padding_mode: str = "zeros",
42
+ is_causal: bool = True,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.in_channels = in_channels
47
+ self.out_channels = out_channels
48
+ self.is_causal = is_causal
49
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
50
+
51
+ dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
52
+ stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
53
+ height_pad = self.kernel_size[1] // 2
54
+ width_pad = self.kernel_size[2] // 2
55
+ padding = (0, height_pad, width_pad)
56
+
57
+ self.conv = nn.Conv3d(
58
+ in_channels,
59
+ out_channels,
60
+ self.kernel_size,
61
+ stride=stride,
62
+ dilation=dilation,
63
+ groups=groups,
64
+ padding=padding,
65
+ padding_mode=padding_mode,
66
+ )
67
+
68
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
69
+ time_kernel_size = self.kernel_size[0]
70
+
71
+ if self.is_causal:
72
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
73
+ hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
74
+ else:
75
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
76
+ pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
77
+ hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2)
78
+
79
+ hidden_states = self.conv(hidden_states)
80
+ return hidden_states
81
+
82
+
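Unlike the Hunyuan variant above, the temporal padding here is built by repeating edge frames explicitly: causal mode repeats the first frame kernel-1 times on the left, while non-causal mode pads both sides with (kernel-1)//2 copies of the first and last frames. A small sketch of the two padding patterns (illustrative only):

    import torch

    frames = torch.arange(4.0).reshape(1, 1, 4, 1, 1)   # frame indices 0..3 stored as values
    k = 3
    causal = torch.cat([frames[:, :, :1].repeat(1, 1, k - 1, 1, 1), frames], dim=2)
    symmetric = torch.cat(
        [frames[:, :, :1].repeat(1, 1, (k - 1) // 2, 1, 1), frames,
         frames[:, :, -1:].repeat(1, 1, (k - 1) // 2, 1, 1)], dim=2)
    print(causal.flatten().tolist())     # [0.0, 0.0, 0.0, 1.0, 2.0, 3.0]: frame t sees only inputs <= t
    print(symmetric.flatten().tolist())  # [0.0, 0.0, 1.0, 2.0, 3.0, 3.0]: each frame also sees one future frame
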
83
+ class LTXVideoResnetBlock3d(nn.Module):
84
+ r"""
85
+ A 3D ResNet block used in the LTXVideo model.
86
+
87
+ Args:
88
+ in_channels (`int`):
89
+ Number of input channels.
90
+ out_channels (`int`, *optional*):
91
+ Number of output channels. If None, defaults to `in_channels`.
92
+ dropout (`float`, defaults to `0.0`):
93
+ Dropout rate.
94
+ eps (`float`, defaults to `1e-6`):
95
+ Epsilon value for normalization layers.
96
+ elementwise_affine (`bool`, defaults to `False`):
97
+ Whether to enable elementwise affinity in the normalization layers.
98
+ non_linearity (`str`, defaults to `"swish"`):
99
+ Activation function to use.
100
+ conv_shortcut (bool, defaults to `False`):
101
+ Whether or not to use a convolution shortcut.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ in_channels: int,
107
+ out_channels: Optional[int] = None,
108
+ dropout: float = 0.0,
109
+ eps: float = 1e-6,
110
+ elementwise_affine: bool = False,
111
+ non_linearity: str = "swish",
112
+ is_causal: bool = True,
113
+ inject_noise: bool = False,
114
+ timestep_conditioning: bool = False,
115
+ ) -> None:
116
+ super().__init__()
117
+
118
+ out_channels = out_channels or in_channels
119
+
120
+ self.nonlinearity = get_activation(non_linearity)
121
+
122
+ self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
123
+ self.conv1 = LTXVideoCausalConv3d(
124
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
125
+ )
126
+
127
+ self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
128
+ self.dropout = nn.Dropout(dropout)
129
+ self.conv2 = LTXVideoCausalConv3d(
130
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
131
+ )
132
+
133
+ self.norm3 = None
134
+ self.conv_shortcut = None
135
+ if in_channels != out_channels:
136
+ self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
137
+ self.conv_shortcut = LTXVideoCausalConv3d(
138
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
139
+ )
140
+
141
+ self.per_channel_scale1 = None
142
+ self.per_channel_scale2 = None
143
+ if inject_noise:
144
+ self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
145
+ self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
146
+
147
+ self.scale_shift_table = None
148
+ if timestep_conditioning:
149
+ self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
150
+
151
+ def forward(
152
+ self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
153
+ ) -> torch.Tensor:
154
+ hidden_states = inputs
155
+
156
+ hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
157
+
158
+ if self.scale_shift_table is not None:
159
+ temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
160
+ shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
161
+ hidden_states = hidden_states * (1 + scale_1) + shift_1
162
+
163
+ hidden_states = self.nonlinearity(hidden_states)
164
+ hidden_states = self.conv1(hidden_states)
165
+
166
+ if self.per_channel_scale1 is not None:
167
+ spatial_shape = hidden_states.shape[-2:]
168
+ spatial_noise = torch.randn(
169
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
170
+ )[None]
171
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
172
+
173
+ hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
174
+
175
+ if self.scale_shift_table is not None:
176
+ hidden_states = hidden_states * (1 + scale_2) + shift_2
177
+
178
+ hidden_states = self.nonlinearity(hidden_states)
179
+ hidden_states = self.dropout(hidden_states)
180
+ hidden_states = self.conv2(hidden_states)
181
+
182
+ if self.per_channel_scale2 is not None:
183
+ spatial_shape = hidden_states.shape[-2:]
184
+ spatial_noise = torch.randn(
185
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
186
+ )[None]
187
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
188
+
189
+ if self.norm3 is not None:
190
+ inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
191
+
192
+ if self.conv_shortcut is not None:
193
+ inputs = self.conv_shortcut(inputs)
194
+
195
+ hidden_states = hidden_states + inputs
196
+ return hidden_states
197
+
198
+
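When timestep_conditioning is enabled, the block modulates its two normalized activations with shift/scale pairs derived from the timestep embedding plus the learned (4, channels) scale_shift_table, in the AdaLN style. A shape-level sketch of that modulation (illustrative only; the tensors here are random stand-ins):

    import torch

    batch, channels = 2, 8
    temb = torch.randn(batch, 4 * channels, 1, 1, 1)           # projected timestep embedding
    table = torch.randn(4, channels) / channels**0.5           # stand-in for scale_shift_table
    chunks = temb.unflatten(1, (4, -1)) + table[None, ..., None, None, None]
    shift_1, scale_1, shift_2, scale_2 = chunks.unbind(dim=1)  # each (batch, channels, 1, 1, 1)
    h = torch.randn(batch, channels, 3, 4, 4)                  # (B, C, T, H, W) activations
    h = h * (1 + scale_1) + shift_1                            # modulation applied before conv1
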
199
+ class LTXVideoUpsampler3d(nn.Module):
200
+ def __init__(
201
+ self,
202
+ in_channels: int,
203
+ stride: Union[int, Tuple[int, int, int]] = 1,
204
+ is_causal: bool = True,
205
+ residual: bool = False,
206
+ upscale_factor: int = 1,
207
+ ) -> None:
208
+ super().__init__()
209
+
210
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
211
+ self.residual = residual
212
+ self.upscale_factor = upscale_factor
213
+
214
+ out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
215
+
216
+ self.conv = LTXVideoCausalConv3d(
217
+ in_channels=in_channels,
218
+ out_channels=out_channels,
219
+ kernel_size=3,
220
+ stride=1,
221
+ is_causal=is_causal,
222
+ )
223
+
224
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
225
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
226
+
227
+ if self.residual:
228
+ residual = hidden_states.reshape(
229
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
230
+ )
231
+ residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
232
+ repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
233
+ residual = residual.repeat(1, repeats, 1, 1, 1)
234
+ residual = residual[:, :, self.stride[0] - 1 :]
235
+
236
+ hidden_states = self.conv(hidden_states)
237
+ hidden_states = hidden_states.reshape(
238
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
239
+ )
240
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
241
+ hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
242
+
243
+ if self.residual:
244
+ hidden_states = hidden_states + residual
245
+
246
+ return hidden_states
247
+
248
+
249
+ class LTXVideoDownBlock3D(nn.Module):
250
+ r"""
251
+ Down block used in the LTXVideo model.
252
+
253
+ Args:
254
+ in_channels (`int`):
255
+ Number of input channels.
256
+ out_channels (`int`, *optional*):
257
+ Number of output channels. If None, defaults to `in_channels`.
258
+ num_layers (`int`, defaults to `1`):
259
+ Number of resnet layers.
260
+ dropout (`float`, defaults to `0.0`):
261
+ Dropout rate.
262
+ resnet_eps (`float`, defaults to `1e-6`):
263
+ Epsilon value for normalization layers.
264
+ resnet_act_fn (`str`, defaults to `"swish"`):
265
+ Activation function to use.
266
+ spatio_temporal_scale (`bool`, defaults to `True`):
267
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
268
+ Whether or not to downsample across temporal dimension.
269
+ is_causal (`bool`, defaults to `True`):
270
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
271
+ """
272
+
273
+ _supports_gradient_checkpointing = True
274
+
275
+ def __init__(
276
+ self,
277
+ in_channels: int,
278
+ out_channels: Optional[int] = None,
279
+ num_layers: int = 1,
280
+ dropout: float = 0.0,
281
+ resnet_eps: float = 1e-6,
282
+ resnet_act_fn: str = "swish",
283
+ spatio_temporal_scale: bool = True,
284
+ is_causal: bool = True,
285
+ ):
286
+ super().__init__()
287
+
288
+ out_channels = out_channels or in_channels
289
+
290
+ resnets = []
291
+ for _ in range(num_layers):
292
+ resnets.append(
293
+ LTXVideoResnetBlock3d(
294
+ in_channels=in_channels,
295
+ out_channels=in_channels,
296
+ dropout=dropout,
297
+ eps=resnet_eps,
298
+ non_linearity=resnet_act_fn,
299
+ is_causal=is_causal,
300
+ )
301
+ )
302
+ self.resnets = nn.ModuleList(resnets)
303
+
304
+ self.downsamplers = None
305
+ if spatio_temporal_scale:
306
+ self.downsamplers = nn.ModuleList(
307
+ [
308
+ LTXVideoCausalConv3d(
309
+ in_channels=in_channels,
310
+ out_channels=in_channels,
311
+ kernel_size=3,
312
+ stride=(2, 2, 2),
313
+ is_causal=is_causal,
314
+ )
315
+ ]
316
+ )
317
+
318
+ self.conv_out = None
319
+ if in_channels != out_channels:
320
+ self.conv_out = LTXVideoResnetBlock3d(
321
+ in_channels=in_channels,
322
+ out_channels=out_channels,
323
+ dropout=dropout,
324
+ eps=resnet_eps,
325
+ non_linearity=resnet_act_fn,
326
+ is_causal=is_causal,
327
+ )
328
+
329
+ self.gradient_checkpointing = False
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ temb: Optional[torch.Tensor] = None,
335
+ generator: Optional[torch.Generator] = None,
336
+ ) -> torch.Tensor:
337
+ r"""Forward method of the `LTXVideoDownBlock3D` class."""
338
+
339
+ for i, resnet in enumerate(self.resnets):
340
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
341
+
342
+ def create_custom_forward(module):
343
+ def create_forward(*inputs):
344
+ return module(*inputs)
345
+
346
+ return create_forward
347
+
348
+ hidden_states = torch.utils.checkpoint.checkpoint(
349
+ create_custom_forward(resnet), hidden_states, temb, generator
350
+ )
351
+ else:
352
+ hidden_states = resnet(hidden_states, temb, generator)
353
+
354
+ if self.downsamplers is not None:
355
+ for downsampler in self.downsamplers:
356
+ hidden_states = downsampler(hidden_states)
357
+
358
+ if self.conv_out is not None:
359
+ hidden_states = self.conv_out(hidden_states, temb, generator)
360
+
361
+ return hidden_states
362
+
363
+
364
+ # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
365
+ class LTXVideoMidBlock3d(nn.Module):
366
+ r"""
367
+ A middle block used in the LTXVideo model.
368
+
369
+ Args:
370
+ in_channels (`int`):
371
+ Number of input channels.
372
+ num_layers (`int`, defaults to `1`):
373
+ Number of resnet layers.
374
+ dropout (`float`, defaults to `0.0`):
375
+ Dropout rate.
376
+ resnet_eps (`float`, defaults to `1e-6`):
377
+ Epsilon value for normalization layers.
378
+ resnet_act_fn (`str`, defaults to `"swish"`):
379
+ Activation function to use.
380
+ is_causal (`bool`, defaults to `True`):
381
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
382
+ """
383
+
384
+ _supports_gradient_checkpointing = True
385
+
386
+ def __init__(
387
+ self,
388
+ in_channels: int,
389
+ num_layers: int = 1,
390
+ dropout: float = 0.0,
391
+ resnet_eps: float = 1e-6,
392
+ resnet_act_fn: str = "swish",
393
+ is_causal: bool = True,
394
+ inject_noise: bool = False,
395
+ timestep_conditioning: bool = False,
396
+ ) -> None:
397
+ super().__init__()
398
+
399
+ self.time_embedder = None
400
+ if timestep_conditioning:
401
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
402
+
403
+ resnets = []
404
+ for _ in range(num_layers):
405
+ resnets.append(
406
+ LTXVideoResnetBlock3d(
407
+ in_channels=in_channels,
408
+ out_channels=in_channels,
409
+ dropout=dropout,
410
+ eps=resnet_eps,
411
+ non_linearity=resnet_act_fn,
412
+ is_causal=is_causal,
413
+ inject_noise=inject_noise,
414
+ timestep_conditioning=timestep_conditioning,
415
+ )
416
+ )
417
+ self.resnets = nn.ModuleList(resnets)
418
+
419
+ self.gradient_checkpointing = False
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ temb: Optional[torch.Tensor] = None,
425
+ generator: Optional[torch.Generator] = None,
426
+ ) -> torch.Tensor:
427
+ r"""Forward method of the `LTXVideoMidBlock3d` class."""
428
+
429
+ if self.time_embedder is not None:
430
+ temb = self.time_embedder(
431
+ timestep=temb.flatten(),
432
+ resolution=None,
433
+ aspect_ratio=None,
434
+ batch_size=hidden_states.size(0),
435
+ hidden_dtype=hidden_states.dtype,
436
+ )
437
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
438
+
439
+ for i, resnet in enumerate(self.resnets):
440
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
441
+
442
+ def create_custom_forward(module):
443
+ def create_forward(*inputs):
444
+ return module(*inputs)
445
+
446
+ return create_forward
447
+
448
+ hidden_states = torch.utils.checkpoint.checkpoint(
449
+ create_custom_forward(resnet), hidden_states, temb, generator
450
+ )
451
+ else:
452
+ hidden_states = resnet(hidden_states, temb, generator)
453
+
454
+ return hidden_states
455
+
456
+
457
+ class LTXVideoUpBlock3d(nn.Module):
458
+ r"""
459
+ Up block used in the LTXVideo model.
460
+
461
+ Args:
462
+ in_channels (`int`):
463
+ Number of input channels.
464
+ out_channels (`int`, *optional*):
465
+ Number of output channels. If None, defaults to `in_channels`.
466
+ num_layers (`int`, defaults to `1`):
467
+ Number of resnet layers.
468
+ dropout (`float`, defaults to `0.0`):
469
+ Dropout rate.
470
+ resnet_eps (`float`, defaults to `1e-6`):
471
+ Epsilon value for normalization layers.
472
+ resnet_act_fn (`str`, defaults to `"swish"`):
473
+ Activation function to use.
474
+ spatio_temporal_scale (`bool`, defaults to `True`):
+ Whether or not to apply a spatio-temporal upsampling layer. If not used, the block keeps the spatial and temporal dimensions of its input.
477
+ is_causal (`bool`, defaults to `True`):
478
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
479
+ """
480
+
481
+ _supports_gradient_checkpointing = True
482
+
483
+ def __init__(
484
+ self,
485
+ in_channels: int,
486
+ out_channels: Optional[int] = None,
487
+ num_layers: int = 1,
488
+ dropout: float = 0.0,
489
+ resnet_eps: float = 1e-6,
490
+ resnet_act_fn: str = "swish",
491
+ spatio_temporal_scale: bool = True,
492
+ is_causal: bool = True,
493
+ inject_noise: bool = False,
494
+ timestep_conditioning: bool = False,
495
+ upsample_residual: bool = False,
496
+ upscale_factor: int = 1,
497
+ ):
498
+ super().__init__()
499
+
500
+ out_channels = out_channels or in_channels
501
+
502
+ self.time_embedder = None
503
+ if timestep_conditioning:
504
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
505
+
506
+ self.conv_in = None
507
+ if in_channels != out_channels:
508
+ self.conv_in = LTXVideoResnetBlock3d(
509
+ in_channels=in_channels,
510
+ out_channels=out_channels,
511
+ dropout=dropout,
512
+ eps=resnet_eps,
513
+ non_linearity=resnet_act_fn,
514
+ is_causal=is_causal,
515
+ inject_noise=inject_noise,
516
+ timestep_conditioning=timestep_conditioning,
517
+ )
518
+
519
+ self.upsamplers = None
520
+ if spatio_temporal_scale:
521
+ self.upsamplers = nn.ModuleList(
522
+ [
523
+ LTXVideoUpsampler3d(
524
+ out_channels * upscale_factor,
525
+ stride=(2, 2, 2),
526
+ is_causal=is_causal,
527
+ residual=upsample_residual,
528
+ upscale_factor=upscale_factor,
529
+ )
530
+ ]
531
+ )
532
+
533
+ resnets = []
534
+ for _ in range(num_layers):
535
+ resnets.append(
536
+ LTXVideoResnetBlock3d(
537
+ in_channels=out_channels,
538
+ out_channels=out_channels,
539
+ dropout=dropout,
540
+ eps=resnet_eps,
541
+ non_linearity=resnet_act_fn,
542
+ is_causal=is_causal,
543
+ inject_noise=inject_noise,
544
+ timestep_conditioning=timestep_conditioning,
545
+ )
546
+ )
547
+ self.resnets = nn.ModuleList(resnets)
548
+
549
+ self.gradient_checkpointing = False
550
+
551
+ def forward(
552
+ self,
553
+ hidden_states: torch.Tensor,
554
+ temb: Optional[torch.Tensor] = None,
555
+ generator: Optional[torch.Generator] = None,
556
+ ) -> torch.Tensor:
557
+ if self.conv_in is not None:
558
+ hidden_states = self.conv_in(hidden_states, temb, generator)
559
+
560
+ if self.time_embedder is not None:
561
+ temb = self.time_embedder(
562
+ timestep=temb.flatten(),
563
+ resolution=None,
564
+ aspect_ratio=None,
565
+ batch_size=hidden_states.size(0),
566
+ hidden_dtype=hidden_states.dtype,
567
+ )
568
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
569
+
570
+ if self.upsamplers is not None:
571
+ for upsampler in self.upsamplers:
572
+ hidden_states = upsampler(hidden_states)
573
+
574
+ for i, resnet in enumerate(self.resnets):
575
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
576
+
577
+ def create_custom_forward(module):
578
+ def create_forward(*inputs):
579
+ return module(*inputs)
580
+
581
+ return create_forward
582
+
583
+ hidden_states = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(resnet), hidden_states, temb, generator
585
+ )
586
+ else:
587
+ hidden_states = resnet(hidden_states, temb, generator)
588
+
589
+ return hidden_states
590
+
591
+
592
+ class LTXVideoEncoder3d(nn.Module):
593
+ r"""
594
+ The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
595
+ representation.
596
+
597
+ Args:
598
+ in_channels (`int`, defaults to 3):
599
+ Number of input channels.
600
+ out_channels (`int`, defaults to 128):
601
+ Number of latent channels.
602
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
603
+ The number of output channels for each block.
604
+ spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`):
605
+ Whether a block should contain spatio-temporal downscaling layers or not.
606
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
607
+ The number of layers per block.
608
+ patch_size (`int`, defaults to `4`):
609
+ The size of spatial patches.
610
+ patch_size_t (`int`, defaults to `1`):
611
+ The size of temporal patches.
612
+ resnet_norm_eps (`float`, defaults to `1e-6`):
613
+ Epsilon value for ResNet normalization layers.
614
+ is_causal (`bool`, defaults to `True`):
615
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
616
+ """
617
+
618
+ def __init__(
619
+ self,
620
+ in_channels: int = 3,
621
+ out_channels: int = 128,
622
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
623
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
624
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
625
+ patch_size: int = 4,
626
+ patch_size_t: int = 1,
627
+ resnet_norm_eps: float = 1e-6,
628
+ is_causal: bool = True,
629
+ ):
630
+ super().__init__()
631
+
632
+ self.patch_size = patch_size
633
+ self.patch_size_t = patch_size_t
634
+ self.in_channels = in_channels * patch_size**2
635
+
636
+ output_channel = block_out_channels[0]
637
+
638
+ self.conv_in = LTXVideoCausalConv3d(
639
+ in_channels=self.in_channels,
640
+ out_channels=output_channel,
641
+ kernel_size=3,
642
+ stride=1,
643
+ is_causal=is_causal,
644
+ )
645
+
646
+ # down blocks
647
+ num_block_out_channels = len(block_out_channels)
648
+ self.down_blocks = nn.ModuleList([])
649
+ for i in range(num_block_out_channels):
650
+ input_channel = output_channel
651
+ output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
652
+
653
+ down_block = LTXVideoDownBlock3D(
654
+ in_channels=input_channel,
655
+ out_channels=output_channel,
656
+ num_layers=layers_per_block[i],
657
+ resnet_eps=resnet_norm_eps,
658
+ spatio_temporal_scale=spatio_temporal_scaling[i],
659
+ is_causal=is_causal,
660
+ )
661
+
662
+ self.down_blocks.append(down_block)
663
+
664
+ # mid block
665
+ self.mid_block = LTXVideoMidBlock3d(
666
+ in_channels=output_channel,
667
+ num_layers=layers_per_block[-1],
668
+ resnet_eps=resnet_norm_eps,
669
+ is_causal=is_causal,
670
+ )
671
+
672
+ # out
673
+ self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
674
+ self.conv_act = nn.SiLU()
675
+ self.conv_out = LTXVideoCausalConv3d(
676
+ in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
677
+ )
678
+
679
+ self.gradient_checkpointing = False
680
+
681
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
682
+ r"""The forward method of the `LTXVideoEncoder3d` class."""
683
+
684
+ p = self.patch_size
685
+ p_t = self.patch_size_t
686
+
687
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
688
+ post_patch_num_frames = num_frames // p_t
689
+ post_patch_height = height // p
690
+ post_patch_width = width // p
691
+
692
+ hidden_states = hidden_states.reshape(
693
+ batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
694
+ )
695
+ # Thanks for driving me insane with the weird patching order :(
696
+ hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
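+ # Shape note (derived from the reshape/permute above): the (p_t, p, p) patch axes are folded into the
+ # channel dimension, so conv_in receives a tensor of shape
+ # [batch, num_channels * p_t * p * p, num_frames // p_t, height // p, width // p].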
697
+ hidden_states = self.conv_in(hidden_states)
698
+
699
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
700
+
701
+ def create_custom_forward(module):
702
+ def create_forward(*inputs):
703
+ return module(*inputs)
704
+
705
+ return create_forward
706
+
707
+ for down_block in self.down_blocks:
708
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
709
+
710
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
711
+ else:
712
+ for down_block in self.down_blocks:
713
+ hidden_states = down_block(hidden_states)
714
+
715
+ hidden_states = self.mid_block(hidden_states)
716
+
717
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
718
+ hidden_states = self.conv_act(hidden_states)
719
+ hidden_states = self.conv_out(hidden_states)
720
+
721
+ last_channel = hidden_states[:, -1:]
722
+ last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
723
+ hidden_states = torch.cat([hidden_states, last_channel], dim=1)
724
+
725
+ return hidden_states
726
+
727
+
728
+ class LTXVideoDecoder3d(nn.Module):
729
+ r"""
730
+ The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
731
+ sample.
732
+
733
+ Args:
734
+ in_channels (`int`, defaults to 128):
735
+ Number of latent channels.
736
+ out_channels (`int`, defaults to 3):
737
+ Number of output channels.
738
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
739
+ The number of output channels for each block.
740
+ spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`):
741
+ Whether a block should contain spatio-temporal upscaling layers or not.
742
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
743
+ The number of layers per block.
744
+ patch_size (`int`, defaults to `4`):
745
+ The size of spatial patches.
746
+ patch_size_t (`int`, defaults to `1`):
747
+ The size of temporal patches.
748
+ resnet_norm_eps (`float`, defaults to `1e-6`):
749
+ Epsilon value for ResNet normalization layers.
750
+ is_causal (`bool`, defaults to `False`):
751
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
752
+ timestep_conditioning (`bool`, defaults to `False`):
753
+ Whether to condition the model on timesteps.
754
+ """
755
+
756
+ def __init__(
757
+ self,
758
+ in_channels: int = 128,
759
+ out_channels: int = 3,
760
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
761
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
762
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
763
+ patch_size: int = 4,
764
+ patch_size_t: int = 1,
765
+ resnet_norm_eps: float = 1e-6,
766
+ is_causal: bool = False,
767
+ inject_noise: Tuple[bool, ...] = (False, False, False, False),
768
+ timestep_conditioning: bool = False,
769
+ upsample_residual: Tuple[bool, ...] = (False, False, False, False),
770
+ upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
771
+ ) -> None:
772
+ super().__init__()
773
+
774
+ self.patch_size = patch_size
775
+ self.patch_size_t = patch_size_t
776
+ self.out_channels = out_channels * patch_size**2
777
+
778
+ block_out_channels = tuple(reversed(block_out_channels))
779
+ spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
780
+ layers_per_block = tuple(reversed(layers_per_block))
781
+ inject_noise = tuple(reversed(inject_noise))
782
+ upsample_residual = tuple(reversed(upsample_residual))
783
+ upsample_factor = tuple(reversed(upsample_factor))
784
+ output_channel = block_out_channels[0]
785
+
786
+ self.conv_in = LTXVideoCausalConv3d(
787
+ in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
788
+ )
789
+
790
+ self.mid_block = LTXVideoMidBlock3d(
791
+ in_channels=output_channel,
792
+ num_layers=layers_per_block[0],
793
+ resnet_eps=resnet_norm_eps,
794
+ is_causal=is_causal,
795
+ inject_noise=inject_noise[0],
796
+ timestep_conditioning=timestep_conditioning,
797
+ )
798
+
799
+ # up blocks
800
+ num_block_out_channels = len(block_out_channels)
801
+ self.up_blocks = nn.ModuleList([])
802
+ for i in range(num_block_out_channels):
803
+ input_channel = output_channel // upsample_factor[i]
804
+ output_channel = block_out_channels[i] // upsample_factor[i]
805
+
806
+ up_block = LTXVideoUpBlock3d(
807
+ in_channels=input_channel,
808
+ out_channels=output_channel,
809
+ num_layers=layers_per_block[i + 1],
810
+ resnet_eps=resnet_norm_eps,
811
+ spatio_temporal_scale=spatio_temporal_scaling[i],
812
+ is_causal=is_causal,
813
+ inject_noise=inject_noise[i + 1],
814
+ timestep_conditioning=timestep_conditioning,
815
+ upsample_residual=upsample_residual[i],
816
+ upscale_factor=upsample_factor[i],
817
+ )
818
+
819
+ self.up_blocks.append(up_block)
820
+
821
+ # out
822
+ self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
823
+ self.conv_act = nn.SiLU()
824
+ self.conv_out = LTXVideoCausalConv3d(
825
+ in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
826
+ )
827
+
828
+ # timestep embedding
829
+ self.time_embedder = None
830
+ self.scale_shift_table = None
831
+ if timestep_conditioning:
832
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
833
+ self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
834
+
835
+ self.gradient_checkpointing = False
836
+
837
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
838
+ hidden_states = self.conv_in(hidden_states)
839
+
840
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
841
+
842
+ def create_custom_forward(module):
843
+ def create_forward(*inputs):
844
+ return module(*inputs)
845
+
846
+ return create_forward
847
+
848
+ hidden_states = torch.utils.checkpoint.checkpoint(
849
+ create_custom_forward(self.mid_block), hidden_states, temb
850
+ )
851
+
852
+ for up_block in self.up_blocks:
853
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
854
+ else:
855
+ hidden_states = self.mid_block(hidden_states, temb)
856
+
857
+ for up_block in self.up_blocks:
858
+ hidden_states = up_block(hidden_states, temb)
859
+
860
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
861
+
862
+ if self.time_embedder is not None:
863
+ temb = self.time_embedder(
864
+ timestep=temb.flatten(),
865
+ resolution=None,
866
+ aspect_ratio=None,
867
+ batch_size=hidden_states.size(0),
868
+ hidden_dtype=hidden_states.dtype,
869
+ )
870
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
871
+ temb = temb + self.scale_shift_table[None, ..., None, None, None]
872
+ shift, scale = temb.unbind(dim=1)
873
+ hidden_states = hidden_states * (1 + scale) + shift
874
+
875
+ hidden_states = self.conv_act(hidden_states)
876
+ hidden_states = self.conv_out(hidden_states)
877
+
878
+ p = self.patch_size
879
+ p_t = self.patch_size_t
880
+
881
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
882
+ hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width)
883
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3)
884
+
885
+ return hidden_states
886
+
887
+
888
+ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
889
+ r"""
890
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
891
+ [LTX](https://huggingface.co/Lightricks/LTX-Video).
892
+
893
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
894
+ for all models (such as downloading or saving).
895
+
896
+ Args:
897
+ in_channels (`int`, defaults to `3`):
898
+ Number of input channels.
899
+ out_channels (`int`, defaults to `3`):
900
+ Number of output channels.
901
+ latent_channels (`int`, defaults to `128`):
902
+ Number of latent channels.
903
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
904
+ The number of output channels for each block.
905
+ spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`):
906
+ Whether a block should contain spatio-temporal downscaling or not.
907
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
908
+ The number of layers per block.
909
+ patch_size (`int`, defaults to `4`):
910
+ The size of spatial patches.
911
+ patch_size_t (`int`, defaults to `1`):
912
+ The size of temporal patches.
913
+ resnet_norm_eps (`float`, defaults to `1e-6`):
914
+ Epsilon value for ResNet normalization layers.
915
+ scaling_factor (`float`, *optional*, defaults to `1.0`):
916
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
917
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
918
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
919
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
920
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
921
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
922
+ encoder_causal (`bool`, defaults to `True`):
923
+ Whether the encoder should behave causally (future frames depend only on past frames) or not.
924
+ decoder_causal (`bool`, defaults to `False`):
925
+ Whether the decoder should behave causally (future frames depend only on past frames) or not.
926
+ """
927
+
928
+ _supports_gradient_checkpointing = True
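+ # Minimal usage sketch (assumes the `Lightricks/LTX-Video` checkpoint ships this VAE under a `vae`
+ # subfolder; adjust the repository id and dtype to your setup):
+ #
+ #     import torch
+ #     from diffusers import AutoencoderKLLTXVideo
+ #
+ #     vae = AutoencoderKLLTXVideo.from_pretrained(
+ #         "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
+ #     )
+ #     vae.enable_tiling()  # optional: lowers peak memory for large spatial resolutions
+ #     video = torch.randn(1, 3, 9, 256, 256, dtype=torch.bfloat16)  # [B, C, F, H, W]
+ #     latents = vae.encode(video).latent_dist.sample()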
929
+
930
+ @register_to_config
931
+ def __init__(
932
+ self,
933
+ in_channels: int = 3,
934
+ out_channels: int = 3,
935
+ latent_channels: int = 128,
936
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
937
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
938
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
939
+ decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
940
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
941
+ decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
942
+ decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
943
+ upsample_residual: Tuple[bool, ...] = (False, False, False, False),
944
+ upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
945
+ timestep_conditioning: bool = False,
946
+ patch_size: int = 4,
947
+ patch_size_t: int = 1,
948
+ resnet_norm_eps: float = 1e-6,
949
+ scaling_factor: float = 1.0,
950
+ encoder_causal: bool = True,
951
+ decoder_causal: bool = False,
952
+ ) -> None:
953
+ super().__init__()
954
+
955
+ self.encoder = LTXVideoEncoder3d(
956
+ in_channels=in_channels,
957
+ out_channels=latent_channels,
958
+ block_out_channels=block_out_channels,
959
+ spatio_temporal_scaling=spatio_temporal_scaling,
960
+ layers_per_block=layers_per_block,
961
+ patch_size=patch_size,
962
+ patch_size_t=patch_size_t,
963
+ resnet_norm_eps=resnet_norm_eps,
964
+ is_causal=encoder_causal,
965
+ )
966
+ self.decoder = LTXVideoDecoder3d(
967
+ in_channels=latent_channels,
968
+ out_channels=out_channels,
969
+ block_out_channels=decoder_block_out_channels,
970
+ spatio_temporal_scaling=decoder_spatio_temporal_scaling,
971
+ layers_per_block=decoder_layers_per_block,
972
+ patch_size=patch_size,
973
+ patch_size_t=patch_size_t,
974
+ resnet_norm_eps=resnet_norm_eps,
975
+ is_causal=decoder_causal,
976
+ timestep_conditioning=timestep_conditioning,
977
+ inject_noise=decoder_inject_noise,
978
+ upsample_residual=upsample_residual,
979
+ upsample_factor=upsample_factor,
980
+ )
981
+
982
+ latents_mean = torch.zeros((latent_channels,), requires_grad=False)
983
+ latents_std = torch.ones((latent_channels,), requires_grad=False)
984
+ self.register_buffer("latents_mean", latents_mean, persistent=True)
985
+ self.register_buffer("latents_std", latents_std, persistent=True)
986
+
987
+ self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
988
+ self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
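+ # With the default config (patch_size=4, patch_size_t=1 and three `True` entries in
+ # spatio_temporal_scaling), this works out to a 32x spatial and 8x temporal compression ratio.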
989
+
990
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
991
+ # to perform decoding of a single video latent at a time.
992
+ self.use_slicing = False
993
+
994
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
995
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
996
+ # intermediate tiles together, the memory requirement can be lowered.
997
+ self.use_tiling = False
998
+
999
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
1000
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
1001
+ self.use_framewise_encoding = False
1002
+ self.use_framewise_decoding = False
1003
+
1004
+ # This can be configured based on the amount of GPU memory available.
1005
+ # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
1006
+ # Setting it to higher values results in higher memory usage.
1007
+ self.num_sample_frames_batch_size = 16
1008
+ self.num_latent_frames_batch_size = 2
1009
+
1010
+ # The minimal tile height and width for spatial tiling to be used
1011
+ self.tile_sample_min_height = 512
1012
+ self.tile_sample_min_width = 512
1013
+
1014
+ # The minimal distance between two spatial tiles
1015
+ self.tile_sample_stride_height = 448
1016
+ self.tile_sample_stride_width = 448
1017
+
1018
+ def _set_gradient_checkpointing(self, module, value=False):
1019
+ if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
1020
+ module.gradient_checkpointing = value
1021
+
1022
+ def enable_tiling(
1023
+ self,
1024
+ tile_sample_min_height: Optional[int] = None,
1025
+ tile_sample_min_width: Optional[int] = None,
1026
+ tile_sample_stride_height: Optional[float] = None,
1027
+ tile_sample_stride_width: Optional[float] = None,
1028
+ ) -> None:
1029
+ r"""
1030
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1031
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1032
+ processing larger images.
1033
+
1034
+ Args:
1035
+ tile_sample_min_height (`int`, *optional*):
1036
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1037
+ tile_sample_min_width (`int`, *optional*):
1038
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1039
+ tile_sample_stride_height (`int`, *optional*):
1040
+ The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
+ artifacts produced across the height dimension.
1042
+ tile_sample_stride_width (`int`, *optional*):
1043
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
1044
+ artifacts produced across the width dimension.
1045
+ """
1046
+ self.use_tiling = True
1047
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1048
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1049
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
1050
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
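+ # With the defaults above (tile_sample_min_* = 512, tile_sample_stride_* = 448), neighbouring tiles
+ # overlap by 64 pixels in sample space, i.e. 64 // spatial_compression_ratio = 2 latent rows/columns,
+ # which is the region blended by blend_v / blend_h during tiled encoding and decoding.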
1051
+
1052
+ def disable_tiling(self) -> None:
1053
+ r"""
1054
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1055
+ decoding in one step.
1056
+ """
1057
+ self.use_tiling = False
1058
+
1059
+ def enable_slicing(self) -> None:
1060
+ r"""
1061
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1062
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1063
+ """
1064
+ self.use_slicing = True
1065
+
1066
+ def disable_slicing(self) -> None:
1067
+ r"""
1068
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1069
+ decoding in one step.
1070
+ """
1071
+ self.use_slicing = False
1072
+
1073
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1074
+ batch_size, num_channels, num_frames, height, width = x.shape
1075
+
1076
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1077
+ return self.tiled_encode(x)
1078
+
1079
+ if self.use_framewise_encoding:
1080
+ # TODO(aryan): requires investigation
1081
+ raise NotImplementedError(
1082
+ "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
1083
+ "quality issues caused by splitting inference across frame dimension. If you believe this "
1084
+ "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
1085
+ )
1086
+ else:
1087
+ enc = self.encoder(x)
1088
+
1089
+ return enc
1090
+
1091
+ @apply_forward_hook
1092
+ def encode(
1093
+ self, x: torch.Tensor, return_dict: bool = True
1094
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1095
+ """
1096
+ Encode a batch of images into latents.
1097
+
1098
+ Args:
1099
+ x (`torch.Tensor`): Input batch of images.
1100
+ return_dict (`bool`, *optional*, defaults to `True`):
1101
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1102
+
1103
+ Returns:
1104
+ The latent representations of the encoded videos. If `return_dict` is True, a
1105
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1106
+ """
1107
+ if self.use_slicing and x.shape[0] > 1:
1108
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1109
+ h = torch.cat(encoded_slices)
1110
+ else:
1111
+ h = self._encode(x)
1112
+ posterior = DiagonalGaussianDistribution(h)
1113
+
1114
+ if not return_dict:
1115
+ return (posterior,)
1116
+ return AutoencoderKLOutput(latent_dist=posterior)
1117
+
1118
+ def _decode(
1119
+ self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
1120
+ ) -> Union[DecoderOutput, torch.Tensor]:
1121
+ batch_size, num_channels, num_frames, height, width = z.shape
1122
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1123
+ tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1124
+
1125
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
1126
+ return self.tiled_decode(z, temb, return_dict=return_dict)
1127
+
1128
+ if self.use_framewise_decoding:
1129
+ # TODO(aryan): requires investigation
1130
+ raise NotImplementedError(
1131
+ "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
1132
+ "quality issues caused by splitting inference across frame dimension. If you believe this "
1133
+ "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
1134
+ )
1135
+ else:
1136
+ dec = self.decoder(z, temb)
1137
+
1138
+ if not return_dict:
1139
+ return (dec,)
1140
+
1141
+ return DecoderOutput(sample=dec)
1142
+
1143
+ @apply_forward_hook
1144
+ def decode(
1145
+ self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
1146
+ ) -> Union[DecoderOutput, torch.Tensor]:
1147
+ """
1148
+ Decode a batch of images.
1149
+
1150
+ Args:
1151
+ z (`torch.Tensor`): Input batch of latent vectors.
1152
+ return_dict (`bool`, *optional*, defaults to `True`):
1153
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1154
+
1155
+ Returns:
1156
+ [`~models.vae.DecoderOutput`] or `tuple`:
1157
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1158
+ returned.
1159
+ """
1160
+ if self.use_slicing and z.shape[0] > 1:
1161
+ if temb is not None:
1162
+ decoded_slices = [
1163
+ self._decode(z_slice, t_slice).sample for z_slice, t_slice in zip(z.split(1), temb.split(1))
1164
+ ]
1165
+ else:
1166
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1167
+ decoded = torch.cat(decoded_slices)
1168
+ else:
1169
+ decoded = self._decode(z, temb).sample
1170
+
1171
+ if not return_dict:
1172
+ return (decoded,)
1173
+
1174
+ return DecoderOutput(sample=decoded)
1175
+
1176
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1177
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1178
+ for y in range(blend_extent):
1179
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1180
+ y / blend_extent
1181
+ )
1182
+ return b
1183
+
1184
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1185
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1186
+ for x in range(blend_extent):
1187
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1188
+ x / blend_extent
1189
+ )
1190
+ return b
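+ # Both helpers linearly cross-fade the overlapping border of two adjacent tiles: at offset k inside
+ # the overlap, the previous tile `a` contributes with weight (1 - k / blend_extent) and the current
+ # tile `b` with weight k / blend_extent, so the seam fades smoothly from `a` to `b`.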
1191
+
1192
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1193
+ r"""Encode a batch of images using a tiled encoder.
1194
+
1195
+ Args:
1196
+ x (`torch.Tensor`): Input batch of videos.
1197
+
1198
+ Returns:
1199
+ `torch.Tensor`:
1200
+ The latent representation of the encoded videos.
1201
+ """
1202
+ batch_size, num_channels, num_frames, height, width = x.shape
1203
+ latent_height = height // self.spatial_compression_ratio
1204
+ latent_width = width // self.spatial_compression_ratio
1205
+
1206
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1207
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1208
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1209
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1210
+
1211
+ blend_height = tile_latent_min_height - tile_latent_stride_height
1212
+ blend_width = tile_latent_min_width - tile_latent_stride_width
1213
+
1214
+ # Split x into overlapping tiles and encode them separately.
1215
+ # The tiles have an overlap to avoid seams between tiles.
1216
+ rows = []
1217
+ for i in range(0, height, self.tile_sample_stride_height):
1218
+ row = []
1219
+ for j in range(0, width, self.tile_sample_stride_width):
1220
+ if self.use_framewise_encoding:
1221
+ # TODO(aryan): requires investigation
1222
+ raise NotImplementedError(
1223
+ "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
1224
+ "quality issues caused by splitting inference across frame dimension. If you believe this "
1225
+ "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
1226
+ )
1227
+ else:
1228
+ time = self.encoder(
1229
+ x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
1230
+ )
1231
+
1232
+ row.append(time)
1233
+ rows.append(row)
1234
+
1235
+ result_rows = []
1236
+ for i, row in enumerate(rows):
1237
+ result_row = []
1238
+ for j, tile in enumerate(row):
1239
+ # blend the above tile and the left tile
1240
+ # to the current tile and add the current tile to the result row
1241
+ if i > 0:
1242
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1243
+ if j > 0:
1244
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1245
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
1246
+ result_rows.append(torch.cat(result_row, dim=4))
1247
+
1248
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1249
+ return enc
1250
+
1251
+ def tiled_decode(
1252
+ self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
1253
+ ) -> Union[DecoderOutput, torch.Tensor]:
1254
+ r"""
1255
+ Decode a batch of images using a tiled decoder.
1256
+
1257
+ Args:
1258
+ z (`torch.Tensor`): Input batch of latent vectors.
1259
+ return_dict (`bool`, *optional*, defaults to `True`):
1260
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1261
+
1262
+ Returns:
1263
+ [`~models.vae.DecoderOutput`] or `tuple`:
1264
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1265
+ returned.
1266
+ """
1267
+
1268
+ batch_size, num_channels, num_frames, height, width = z.shape
1269
+ sample_height = height * self.spatial_compression_ratio
1270
+ sample_width = width * self.spatial_compression_ratio
1271
+
1272
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1273
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1274
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1275
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1276
+
1277
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1278
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1279
+
1280
+ # Split z into overlapping tiles and decode them separately.
1281
+ # The tiles have an overlap to avoid seams between tiles.
1282
+ rows = []
1283
+ for i in range(0, height, tile_latent_stride_height):
1284
+ row = []
1285
+ for j in range(0, width, tile_latent_stride_width):
1286
+ if self.use_framewise_decoding:
1287
+ # TODO(aryan): requires investigation
1288
+ raise NotImplementedError(
1289
+ "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
1290
+ "quality issues caused by splitting inference across frame dimension. If you believe this "
1291
+ "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
1292
+ )
1293
+ else:
1294
+ time = self.decoder(
1295
+ z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
1296
+ )
1297
+
1298
+ row.append(time)
1299
+ rows.append(row)
1300
+
1301
+ result_rows = []
1302
+ for i, row in enumerate(rows):
1303
+ result_row = []
1304
+ for j, tile in enumerate(row):
1305
+ # blend the above tile and the left tile
1306
+ # to the current tile and add the current tile to the result row
1307
+ if i > 0:
1308
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1309
+ if j > 0:
1310
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1311
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1312
+ result_rows.append(torch.cat(result_row, dim=4))
1313
+
1314
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1315
+
1316
+ if not return_dict:
1317
+ return (dec,)
1318
+
1319
+ return DecoderOutput(sample=dec)
1320
+
1321
+ def forward(
1322
+ self,
1323
+ sample: torch.Tensor,
1324
+ temb: Optional[torch.Tensor] = None,
1325
+ sample_posterior: bool = False,
1326
+ return_dict: bool = True,
1327
+ generator: Optional[torch.Generator] = None,
1328
+ ) -> Union[torch.Tensor, torch.Tensor]:
1329
+ x = sample
1330
+ posterior = self.encode(x).latent_dist
1331
+ if sample_posterior:
1332
+ z = posterior.sample(generator=generator)
1333
+ else:
1334
+ z = posterior.mode()
1335
+ dec = self.decode(z, temb)
1336
+ if not return_dict:
1337
+ return (dec,)
1338
+ return dec
icedit/diffusers/models/autoencoders/autoencoder_kl_mochi.py ADDED
@@ -0,0 +1,1166 @@
+ # Copyright 2024 The Mochi team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from typing import Dict, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import logging
24
+ from ...utils.accelerate_utils import apply_forward_hook
25
+ from ..activations import get_activation
26
+ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
27
+ from ..modeling_outputs import AutoencoderKLOutput
28
+ from ..modeling_utils import ModelMixin
29
+ from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
30
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ class MochiChunkedGroupNorm3D(nn.Module):
37
+ r"""
38
+ Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
39
+ normalization.
40
+
41
+ Args:
42
+ num_channels (int): Number of channels expected in input
43
+ num_groups (int, optional): Number of groups to separate the channels into. Default: 32
44
+ affine (bool, optional): If True, this module has learnable affine parameters. Default: True
45
+ chunk_size (int, optional): Size of each chunk for processing. Default: 8
46
+
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ num_channels: int,
52
+ num_groups: int = 32,
53
+ affine: bool = True,
54
+ chunk_size: int = 8,
55
+ ):
56
+ super().__init__()
57
+ self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
58
+ self.chunk_size = chunk_size
59
+
60
+ def forward(self, x: torch.Tensor = None) -> torch.Tensor:
61
+ batch_size = x.size(0)
62
+
63
+ x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
64
+ output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0)
65
+ output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
66
+
67
+ return output
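+ # Memory note (derived from the code above): the [B, C, T, H, W] input is flattened into a batch of
+ # [B * T, C, H, W] frames and normalized `chunk_size` frames at a time, so group statistics are
+ # computed per frame while peak memory stays bounded.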
68
+
69
+
70
+ class MochiResnetBlock3D(nn.Module):
71
+ r"""
72
+ A 3D ResNet block used in the Mochi model.
73
+
74
+ Args:
75
+ in_channels (`int`):
76
+ Number of input channels.
77
+ out_channels (`int`, *optional*):
78
+ Number of output channels. If None, defaults to `in_channels`.
79
+ non_linearity (`str`, defaults to `"swish"`):
80
+ Activation function to use.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: Optional[int] = None,
87
+ act_fn: str = "swish",
88
+ ):
89
+ super().__init__()
90
+
91
+ out_channels = out_channels or in_channels
92
+
93
+ self.in_channels = in_channels
94
+ self.out_channels = out_channels
95
+ self.nonlinearity = get_activation(act_fn)
96
+
97
+ self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
98
+ self.conv1 = CogVideoXCausalConv3d(
99
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
100
+ )
101
+ self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
102
+ self.conv2 = CogVideoXCausalConv3d(
103
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
104
+ )
105
+
106
+ def forward(
107
+ self,
108
+ inputs: torch.Tensor,
109
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
110
+ ) -> torch.Tensor:
111
+ new_conv_cache = {}
112
+ conv_cache = conv_cache or {}
113
+
114
+ hidden_states = inputs
115
+
116
+ hidden_states = self.norm1(hidden_states)
117
+ hidden_states = self.nonlinearity(hidden_states)
118
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
119
+
120
+ hidden_states = self.norm2(hidden_states)
121
+ hidden_states = self.nonlinearity(hidden_states)
122
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
123
+
124
+ hidden_states = hidden_states + inputs
125
+ return hidden_states, new_conv_cache
126
+
127
+
128
+ class MochiDownBlock3D(nn.Module):
129
+ r"""
130
+ A downsampling block used in the Mochi model.
131
+
132
+ Args:
133
+ in_channels (`int`):
134
+ Number of input channels.
135
+ out_channels (`int`, *optional*):
136
+ Number of output channels. If None, defaults to `in_channels`.
137
+ num_layers (`int`, defaults to `1`):
138
+ Number of resnet blocks in the block.
139
+ temporal_expansion (`int`, defaults to `2`):
140
+ Temporal expansion factor.
141
+ spatial_expansion (`int`, defaults to `2`):
142
+ Spatial expansion factor.
143
+ """
144
+
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ out_channels: int,
149
+ num_layers: int = 1,
150
+ temporal_expansion: int = 2,
151
+ spatial_expansion: int = 2,
152
+ add_attention: bool = True,
153
+ ):
154
+ super().__init__()
155
+ self.temporal_expansion = temporal_expansion
156
+ self.spatial_expansion = spatial_expansion
157
+
158
+ self.conv_in = CogVideoXCausalConv3d(
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion),
162
+ stride=(temporal_expansion, spatial_expansion, spatial_expansion),
163
+ pad_mode="replicate",
164
+ )
165
+
166
+ resnets = []
167
+ norms = []
168
+ attentions = []
169
+ for _ in range(num_layers):
170
+ resnets.append(MochiResnetBlock3D(in_channels=out_channels))
171
+ if add_attention:
172
+ norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels))
173
+ attentions.append(
174
+ Attention(
175
+ query_dim=out_channels,
176
+ heads=out_channels // 32,
177
+ dim_head=32,
178
+ qk_norm="l2",
179
+ is_causal=True,
180
+ processor=MochiVaeAttnProcessor2_0(),
181
+ )
182
+ )
183
+ else:
184
+ norms.append(None)
185
+ attentions.append(None)
186
+
187
+ self.resnets = nn.ModuleList(resnets)
188
+ self.norms = nn.ModuleList(norms)
189
+ self.attentions = nn.ModuleList(attentions)
190
+
191
+ self.gradient_checkpointing = False
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
197
+ chunk_size: int = 2**15,
198
+ ) -> torch.Tensor:
199
+ r"""Forward method of the `MochiDownBlock3D` class."""
200
+
201
+ new_conv_cache = {}
202
+ conv_cache = conv_cache or {}
203
+
204
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states)
205
+
206
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
207
+ conv_cache_key = f"resnet_{i}"
208
+
209
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
210
+
211
+ def create_custom_forward(module):
212
+ def create_forward(*inputs):
213
+ return module(*inputs)
214
+
215
+ return create_forward
216
+
217
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
218
+ create_custom_forward(resnet),
219
+ hidden_states,
220
+ conv_cache=conv_cache.get(conv_cache_key),
221
+ )
222
+ else:
223
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
224
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
225
+ )
226
+
227
+ if attn is not None:
228
+ residual = hidden_states
229
+ hidden_states = norm(hidden_states)
230
+
231
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
232
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
233
+
234
+ # Perform attention in chunks to avoid the following error:
235
+ # RuntimeError: CUDA error: invalid configuration argument
236
+ if hidden_states.size(0) <= chunk_size:
237
+ hidden_states = attn(hidden_states)
238
+ else:
239
+ hidden_states_chunks = []
240
+ for i in range(0, hidden_states.size(0), chunk_size):
241
+ hidden_states_chunk = hidden_states[i : i + chunk_size]
242
+ hidden_states_chunk = attn(hidden_states_chunk)
243
+ hidden_states_chunks.append(hidden_states_chunk)
244
+ hidden_states = torch.cat(hidden_states_chunks)
245
+
246
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
247
+
248
+ hidden_states = residual + hidden_states
249
+
250
+ return hidden_states, new_conv_cache
251
+
252
+
253
+ class MochiMidBlock3D(nn.Module):
254
+ r"""
255
+ A middle block used in the Mochi model.
256
+
257
+ Args:
258
+ in_channels (`int`):
259
+ Number of input channels.
260
+ num_layers (`int`, defaults to `3`):
261
+ Number of resnet blocks in the block.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ in_channels: int, # 768
267
+ num_layers: int = 3,
268
+ add_attention: bool = True,
269
+ ):
270
+ super().__init__()
271
+
272
+ resnets = []
273
+ norms = []
274
+ attentions = []
275
+
276
+ for _ in range(num_layers):
277
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
278
+
279
+ if add_attention:
280
+ norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels))
281
+ attentions.append(
282
+ Attention(
283
+ query_dim=in_channels,
284
+ heads=in_channels // 32,
285
+ dim_head=32,
286
+ qk_norm="l2",
287
+ is_causal=True,
288
+ processor=MochiVaeAttnProcessor2_0(),
289
+ )
290
+ )
291
+ else:
292
+ norms.append(None)
293
+ attentions.append(None)
294
+
295
+ self.resnets = nn.ModuleList(resnets)
296
+ self.norms = nn.ModuleList(norms)
297
+ self.attentions = nn.ModuleList(attentions)
298
+
299
+ self.gradient_checkpointing = False
300
+
301
+ def forward(
302
+ self,
303
+ hidden_states: torch.Tensor,
304
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
305
+ ) -> torch.Tensor:
306
+ r"""Forward method of the `MochiMidBlock3D` class."""
307
+
308
+ new_conv_cache = {}
309
+ conv_cache = conv_cache or {}
310
+
311
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
312
+ conv_cache_key = f"resnet_{i}"
313
+
314
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
315
+
316
+ def create_custom_forward(module):
317
+ def create_forward(*inputs):
318
+ return module(*inputs)
319
+
320
+ return create_forward
321
+
322
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
323
+ create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
324
+ )
325
+ else:
326
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
327
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
328
+ )
329
+
330
+ if attn is not None:
331
+ residual = hidden_states
332
+ hidden_states = norm(hidden_states)
333
+
334
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
335
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
336
+ hidden_states = attn(hidden_states)
337
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
338
+
339
+ hidden_states = residual + hidden_states
340
+
341
+ return hidden_states, new_conv_cache
342
+
343
+
344
+ class MochiUpBlock3D(nn.Module):
345
+ r"""
346
+ An upsampling block used in the Mochi model.
347
+
348
+ Args:
349
+ in_channels (`int`):
350
+ Number of input channels.
351
+ out_channels (`int`, *optional*):
352
+ Number of output channels. If None, defaults to `in_channels`.
353
+ num_layers (`int`, defaults to `1`):
354
+ Number of resnet blocks in the block.
355
+ temporal_expansion (`int`, defaults to `2`):
356
+ Temporal expansion factor.
357
+ spatial_expansion (`int`, defaults to `2`):
358
+ Spatial expansion factor.
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ in_channels: int,
364
+ out_channels: int,
365
+ num_layers: int = 1,
366
+ temporal_expansion: int = 2,
367
+ spatial_expansion: int = 2,
368
+ ):
369
+ super().__init__()
370
+ self.temporal_expansion = temporal_expansion
371
+ self.spatial_expansion = spatial_expansion
372
+
373
+ resnets = []
374
+ for _ in range(num_layers):
375
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
376
+ self.resnets = nn.ModuleList(resnets)
377
+
378
+ self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2)
379
+
380
+ self.gradient_checkpointing = False
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states: torch.Tensor,
385
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
386
+ ) -> torch.Tensor:
387
+ r"""Forward method of the `MochiUpBlock3D` class."""
388
+
389
+ new_conv_cache = {}
390
+ conv_cache = conv_cache or {}
391
+
392
+ for i, resnet in enumerate(self.resnets):
393
+ conv_cache_key = f"resnet_{i}"
394
+
395
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
396
+
397
+ def create_custom_forward(module):
398
+ def create_forward(*inputs):
399
+ return module(*inputs)
400
+
401
+ return create_forward
402
+
403
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
404
+ create_custom_forward(resnet),
405
+ hidden_states,
406
+ conv_cache=conv_cache.get(conv_cache_key),
407
+ )
408
+ else:
409
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
410
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
411
+ )
412
+
413
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
414
+ hidden_states = self.proj(hidden_states)
415
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
416
+
417
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
418
+ st = self.temporal_expansion
419
+ sh = self.spatial_expansion
420
+ sw = self.spatial_expansion
421
+
422
+ # Reshape and unpatchify
423
+ hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width)
424
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
425
+ hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw)
426
+
427
+ return hidden_states, new_conv_cache
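A minimal sketch (editorial, not part of the diff) of the unpatchify step in `MochiUpBlock3D.forward` above: the linear projection expands channels by `temporal_expansion * spatial_expansion**2`, and the view/permute/view sequence folds those factors back into the time, height, and width axes. The shapes below are illustrative assumptions.

```python
import torch

b, c, t, h, w = 1, 4, 3, 8, 8
st, sh, sw = 2, 2, 2  # temporal / spatial expansion factors

x = torch.randn(b, c * st * sh * sw, t, h, w)       # output of self.proj, channels-first
x = x.view(b, -1, st, sh, sw, t, h, w)              # split the expansion factors out of the channel dim
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()  # interleave them next to t, h, w
x = x.view(b, -1, t * st, h * sh, w * sw)           # fold them into the spatio-temporal axes
assert x.shape == (b, c, t * st, h * sh, w * sw)
```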
428
+
429
+
430
+ class FourierFeatures(nn.Module):
431
+ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
432
+ super().__init__()
433
+
434
+ self.start = start
435
+ self.stop = stop
436
+ self.step = step
437
+
438
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
439
+ r"""Forward method of the `FourierFeatures` class."""
440
+ original_dtype = inputs.dtype
441
+ inputs = inputs.to(torch.float32)
442
+ num_channels = inputs.shape[1]
443
+ num_freqs = (self.stop - self.start) // self.step
444
+
445
+ freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device)
446
+ w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
447
+ w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
448
+
449
+ # Interleaved repeat of input channels to match w
450
+ h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
451
+ # Scale channels by frequency.
452
+ h = w * h
453
+
454
+ return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
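A short sketch (editorial) of the channel bookkeeping in `FourierFeatures` above: the output concatenates the input with sin/cos of the scaled input, so channels grow from `C` to `C * (1 + 2 * num_freqs)`. With the defaults (`start=6`, `stop=8`, `step=1`) a 3-channel video becomes 15 channels, which matches the `in_channels=15` default of `AutoencoderKLMochi` further below.

```python
num_channels, start, stop, step = 3, 6, 8, 1        # RGB video and the FourierFeatures defaults
num_freqs = (stop - start) // step                  # 2 frequencies
out_channels = num_channels * (1 + 2 * num_freqs)   # original + sin + cos per frequency
assert out_channels == 15
```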
455
+
456
+
457
+ class MochiEncoder3D(nn.Module):
458
+ r"""
459
+ The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
460
+ representation.
461
+
462
+ Args:
463
+ in_channels (`int`, *optional*):
464
+ The number of input channels.
465
+ out_channels (`int`, *optional*):
466
+ The number of output channels.
467
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
468
+ The number of output channels for each block.
469
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
470
+ The number of resnet blocks for each block.
471
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
472
+ The temporal expansion factor for each of the up blocks.
473
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
474
+ The spatial expansion factor for each of the up blocks.
475
+ act_fn (`str`, *optional*, defaults to `"swish"`):
476
+ The non-linearity to use in the encoder.
477
+ """
478
+
479
+ def __init__(
480
+ self,
481
+ in_channels: int,
482
+ out_channels: int,
483
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
484
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
485
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
486
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
487
+ add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
488
+ act_fn: str = "swish",
489
+ ):
490
+ super().__init__()
491
+
492
+ self.nonlinearity = get_activation(act_fn)
493
+
494
+ self.fourier_features = FourierFeatures()
495
+ self.proj_in = nn.Linear(in_channels, block_out_channels[0])
496
+ self.block_in = MochiMidBlock3D(
497
+ in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0]
498
+ )
499
+
500
+ down_blocks = []
501
+ for i in range(len(block_out_channels) - 1):
502
+ down_block = MochiDownBlock3D(
503
+ in_channels=block_out_channels[i],
504
+ out_channels=block_out_channels[i + 1],
505
+ num_layers=layers_per_block[i + 1],
506
+ temporal_expansion=temporal_expansions[i],
507
+ spatial_expansion=spatial_expansions[i],
508
+ add_attention=add_attention_block[i + 1],
509
+ )
510
+ down_blocks.append(down_block)
511
+ self.down_blocks = nn.ModuleList(down_blocks)
512
+
513
+ self.block_out = MochiMidBlock3D(
514
+ in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1]
515
+ )
516
+ self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
517
+ self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
+ self.gradient_checkpointing = False  # referenced by the gradient-checkpointing branch in `forward`
518
+
519
+ def forward(
520
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
521
+ ) -> torch.Tensor:
522
+ r"""Forward method of the `MochiEncoder3D` class."""
523
+
524
+ new_conv_cache = {}
525
+ conv_cache = conv_cache or {}
526
+
527
+ hidden_states = self.fourier_features(hidden_states)
528
+
529
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
530
+ hidden_states = self.proj_in(hidden_states)
531
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
532
+
533
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
534
+
535
+ def create_custom_forward(module):
536
+ def create_forward(*inputs):
537
+ return module(*inputs)
538
+
539
+ return create_forward
540
+
541
+ hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
543
+ )
544
+
545
+ for i, down_block in enumerate(self.down_blocks):
546
+ conv_cache_key = f"down_block_{i}"
547
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
548
+ create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
549
+ )
550
+ else:
551
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
552
+ hidden_states, conv_cache=conv_cache.get("block_in")
553
+ )
554
+
555
+ for i, down_block in enumerate(self.down_blocks):
556
+ conv_cache_key = f"down_block_{i}"
557
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
558
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
559
+ )
560
+
561
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
562
+ hidden_states, conv_cache=conv_cache.get("block_out")
563
+ )
564
+
565
+ hidden_states = self.norm_out(hidden_states)
566
+ hidden_states = self.nonlinearity(hidden_states)
567
+
568
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
569
+ hidden_states = self.proj_out(hidden_states)
570
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
571
+
572
+ return hidden_states, new_conv_cache
573
+
574
+
575
+ class MochiDecoder3D(nn.Module):
576
+ r"""
577
+ The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
578
+ sample.
579
+
580
+ Args:
581
+ in_channels (`int`, *optional*):
582
+ The number of input channels.
583
+ out_channels (`int`, *optional*):
584
+ The number of output channels.
585
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
586
+ The number of output channels for each block.
587
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
588
+ The number of resnet blocks for each block.
589
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
590
+ The temporal expansion factor for each of the up blocks.
591
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
592
+ The spatial expansion factor for each of the up blocks.
593
+ act_fn (`str`, *optional*, defaults to `"swish"`):
594
+ The non-linearity to use in the decoder.
595
+ """
596
+
597
+ def __init__(
598
+ self,
599
+ in_channels: int, # 12
600
+ out_channels: int, # 3
601
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
602
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
603
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
604
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
605
+ act_fn: str = "swish",
606
+ ):
607
+ super().__init__()
608
+
609
+ self.nonlinearity = get_activation(act_fn)
610
+
611
+ self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
612
+ self.block_in = MochiMidBlock3D(
613
+ in_channels=block_out_channels[-1],
614
+ num_layers=layers_per_block[-1],
615
+ add_attention=False,
616
+ )
617
+
618
+ up_blocks = []
619
+ for i in range(len(block_out_channels) - 1):
620
+ up_block = MochiUpBlock3D(
621
+ in_channels=block_out_channels[-i - 1],
622
+ out_channels=block_out_channels[-i - 2],
623
+ num_layers=layers_per_block[-i - 2],
624
+ temporal_expansion=temporal_expansions[-i - 1],
625
+ spatial_expansion=spatial_expansions[-i - 1],
626
+ )
627
+ up_blocks.append(up_block)
628
+ self.up_blocks = nn.ModuleList(up_blocks)
629
+
630
+ self.block_out = MochiMidBlock3D(
631
+ in_channels=block_out_channels[0],
632
+ num_layers=layers_per_block[0],
633
+ add_attention=False,
634
+ )
635
+ self.proj_out = nn.Linear(block_out_channels[0], out_channels)
636
+
637
+ self.gradient_checkpointing = False
638
+
639
+ def forward(
640
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
641
+ ) -> torch.Tensor:
642
+ r"""Forward method of the `MochiDecoder3D` class."""
643
+
644
+ new_conv_cache = {}
645
+ conv_cache = conv_cache or {}
646
+
647
+ hidden_states = self.conv_in(hidden_states)
648
+
649
+ # 1. Mid
650
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
651
+
652
+ def create_custom_forward(module):
653
+ def create_forward(*inputs):
654
+ return module(*inputs)
655
+
656
+ return create_forward
657
+
658
+ hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint(
659
+ create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in")
660
+ )
661
+
662
+ for i, up_block in enumerate(self.up_blocks):
663
+ conv_cache_key = f"up_block_{i}"
664
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
665
+ create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key)
666
+ )
667
+ else:
668
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
669
+ hidden_states, conv_cache=conv_cache.get("block_in")
670
+ )
671
+
672
+ for i, up_block in enumerate(self.up_blocks):
673
+ conv_cache_key = f"up_block_{i}"
674
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
675
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
676
+ )
677
+
678
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
679
+ hidden_states, conv_cache=conv_cache.get("block_out")
680
+ )
681
+
682
+ hidden_states = self.nonlinearity(hidden_states)
683
+
684
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
685
+ hidden_states = self.proj_out(hidden_states)
686
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
687
+
688
+ return hidden_states, new_conv_cache
689
+
690
+
691
+ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
692
+ r"""
693
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
694
+ [Mochi 1 preview](https://github.com/genmoai/models).
695
+
696
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
697
+ for all models (such as downloading or saving).
698
+
699
+ Parameters:
700
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
701
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
702
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
703
+ Tuple of block output channels.
704
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
705
+ scaling_factor (`float`, *optional*, defaults to `1.0`):
706
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
707
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
708
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
709
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
710
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
711
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
712
+ """
713
+
714
+ _supports_gradient_checkpointing = True
715
+ _no_split_modules = ["MochiResnetBlock3D"]
716
+
717
+ @register_to_config
718
+ def __init__(
719
+ self,
720
+ in_channels: int = 15,
721
+ out_channels: int = 3,
722
+ encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
723
+ decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
724
+ latent_channels: int = 12,
725
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
726
+ act_fn: str = "silu",
727
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
728
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
729
+ add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
730
+ latents_mean: Tuple[float, ...] = (
731
+ -0.06730895953510081,
732
+ -0.038011381506090416,
733
+ -0.07477820912866141,
734
+ -0.05565264470995561,
735
+ 0.012767231469026969,
736
+ -0.04703542746246419,
737
+ 0.043896967884726704,
738
+ -0.09346305707025976,
739
+ -0.09918314763016893,
740
+ -0.008729793427399178,
741
+ -0.011931556316503654,
742
+ -0.0321993391887285,
743
+ ),
744
+ latents_std: Tuple[float, ...] = (
745
+ 0.9263795028493863,
746
+ 0.9248894543193766,
747
+ 0.9393059390890617,
748
+ 0.959253732819592,
749
+ 0.8244560132752793,
750
+ 0.917259975397747,
751
+ 0.9294154431013696,
752
+ 1.3720942357788521,
753
+ 0.881393668867029,
754
+ 0.9168315692124348,
755
+ 0.9185249279345552,
756
+ 0.9274757570805041,
757
+ ),
758
+ scaling_factor: float = 1.0,
759
+ ):
760
+ super().__init__()
761
+
762
+ self.encoder = MochiEncoder3D(
763
+ in_channels=in_channels,
764
+ out_channels=latent_channels,
765
+ block_out_channels=encoder_block_out_channels,
766
+ layers_per_block=layers_per_block,
767
+ temporal_expansions=temporal_expansions,
768
+ spatial_expansions=spatial_expansions,
769
+ add_attention_block=add_attention_block,
770
+ act_fn=act_fn,
771
+ )
772
+ self.decoder = MochiDecoder3D(
773
+ in_channels=latent_channels,
774
+ out_channels=out_channels,
775
+ block_out_channels=decoder_block_out_channels,
776
+ layers_per_block=layers_per_block,
777
+ temporal_expansions=temporal_expansions,
778
+ spatial_expansions=spatial_expansions,
779
+ act_fn=act_fn,
780
+ )
781
+
782
+ self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1)
783
+ self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1)
784
+
785
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
786
+ # to perform decoding of a single video latent at a time.
787
+ self.use_slicing = False
788
+
789
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
790
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
791
+ # intermediate tiles together, the memory requirement can be lowered.
792
+ self.use_tiling = False
793
+
794
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
795
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
796
+ self.use_framewise_encoding = False
797
+ self.use_framewise_decoding = False
798
+
799
+ # This can be used to determine the number of output frames in the final decoded video. To maintain consistency with
800
+ # the original implementation, this defaults to `True`.
801
+ # - Original implementation (drop_last_temporal_frames=True):
802
+ # Output frames = (latent_frames - 1) * temporal_compression_ratio + 1
803
+ # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False):
804
+ # Output frames = latent_frames * temporal_compression_ratio
805
+ # The latter case is useful for frame packing and some training/finetuning scenarios where the additional frames are needed.
806
+ self.drop_last_temporal_frames = True
807
+
808
+ # This can be configured based on the amount of GPU memory available.
809
+ # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
810
+ # Setting it to higher values results in higher memory usage.
811
+ self.num_sample_frames_batch_size = 12
812
+ self.num_latent_frames_batch_size = 2
813
+
814
+ # The minimal tile height and width for spatial tiling to be used
815
+ self.tile_sample_min_height = 256
816
+ self.tile_sample_min_width = 256
817
+
818
+ # The minimal distance between two spatial tiles
819
+ self.tile_sample_stride_height = 192
820
+ self.tile_sample_stride_width = 192
821
+
822
+ def _set_gradient_checkpointing(self, module, value=False):
823
+ if isinstance(module, (MochiEncoder3D, MochiDecoder3D)):
824
+ module.gradient_checkpointing = value
825
+
826
+ def enable_tiling(
827
+ self,
828
+ tile_sample_min_height: Optional[int] = None,
829
+ tile_sample_min_width: Optional[int] = None,
830
+ tile_sample_stride_height: Optional[float] = None,
831
+ tile_sample_stride_width: Optional[float] = None,
832
+ ) -> None:
833
+ r"""
834
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
835
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
836
+ processing larger images.
837
+
838
+ Args:
839
+ tile_sample_min_height (`int`, *optional*):
840
+ The minimum height required for a sample to be separated into tiles across the height dimension.
841
+ tile_sample_min_width (`int`, *optional*):
842
+ The minimum width required for a sample to be separated into tiles across the width dimension.
843
+ tile_sample_stride_height (`int`, *optional*):
844
+ The stride between two consecutive vertical tiles. This is to ensure that there are
845
+ no tiling artifacts produced across the height dimension.
846
+ tile_sample_stride_width (`int`, *optional*):
847
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
848
+ artifacts produced across the width dimension.
849
+ """
850
+ self.use_tiling = True
851
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
852
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
853
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
854
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
855
+
856
+ def disable_tiling(self) -> None:
857
+ r"""
858
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
859
+ decoding in one step.
860
+ """
861
+ self.use_tiling = False
862
+
863
+ def enable_slicing(self) -> None:
864
+ r"""
865
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
866
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
867
+ """
868
+ self.use_slicing = True
869
+
870
+ def disable_slicing(self) -> None:
871
+ r"""
872
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
873
+ decoding in one step.
874
+ """
875
+ self.use_slicing = False
876
+
877
+ def _enable_framewise_encoding(self):
878
+ r"""
879
+ Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
880
+ oneshot encoding implementation without current latent replicate padding.
881
+
882
+ Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable
883
+ framewise encoding, encode a video, and try to decode it, there will be a noticeable jittering effect.
884
+ """
885
+ self.use_framewise_encoding = True
886
+ for name, module in self.named_modules():
887
+ if isinstance(module, CogVideoXCausalConv3d):
888
+ module.pad_mode = "constant"
889
+
890
+ def _enable_framewise_decoding(self):
891
+ r"""
892
+ Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the
893
+ oneshot decoding implementation without current latent replicate padding.
894
+ """
895
+ self.use_framewise_decoding = True
896
+ for name, module in self.named_modules():
897
+ if isinstance(module, CogVideoXCausalConv3d):
898
+ module.pad_mode = "constant"
899
+
900
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
901
+ batch_size, num_channels, num_frames, height, width = x.shape
902
+
903
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
904
+ return self.tiled_encode(x)
905
+
906
+ if self.use_framewise_encoding:
907
+ raise NotImplementedError(
908
+ "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
909
+ "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
910
+ )
911
+ else:
912
+ enc, _ = self.encoder(x)
913
+
914
+ return enc
915
+
916
+ @apply_forward_hook
917
+ def encode(
918
+ self, x: torch.Tensor, return_dict: bool = True
919
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
920
+ """
921
+ Encode a batch of videos into latents.
922
+
923
+ Args:
924
+ x (`torch.Tensor`): Input batch of videos.
925
+ return_dict (`bool`, *optional*, defaults to `True`):
926
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
927
+
928
+ Returns:
929
+ The latent representations of the encoded videos. If `return_dict` is True, a
930
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
931
+ """
932
+ if self.use_slicing and x.shape[0] > 1:
933
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
934
+ h = torch.cat(encoded_slices)
935
+ else:
936
+ h = self._encode(x)
937
+
938
+ posterior = DiagonalGaussianDistribution(h)
939
+
940
+ if not return_dict:
941
+ return (posterior,)
942
+ return AutoencoderKLOutput(latent_dist=posterior)
943
+
944
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
945
+ batch_size, num_channels, num_frames, height, width = z.shape
946
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
947
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
948
+
949
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
950
+ return self.tiled_decode(z, return_dict=return_dict)
951
+
952
+ if self.use_framewise_decoding:
953
+ conv_cache = None
954
+ dec = []
955
+
956
+ for i in range(0, num_frames, self.num_latent_frames_batch_size):
957
+ z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
958
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
959
+ dec.append(z_intermediate)
960
+
961
+ dec = torch.cat(dec, dim=2)
962
+ else:
963
+ dec, _ = self.decoder(z)
964
+
965
+ if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio:
966
+ dec = dec[:, :, self.temporal_compression_ratio - 1 :]
967
+
968
+ if not return_dict:
969
+ return (dec,)
970
+
971
+ return DecoderOutput(sample=dec)
972
+
973
+ @apply_forward_hook
974
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
975
+ """
976
+ Decode a batch of images.
977
+
978
+ Args:
979
+ z (`torch.Tensor`): Input batch of latent vectors.
980
+ return_dict (`bool`, *optional*, defaults to `True`):
981
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
982
+
983
+ Returns:
984
+ [`~models.vae.DecoderOutput`] or `tuple`:
985
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
986
+ returned.
987
+ """
988
+ if self.use_slicing and z.shape[0] > 1:
989
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
990
+ decoded = torch.cat(decoded_slices)
991
+ else:
992
+ decoded = self._decode(z).sample
993
+
994
+ if not return_dict:
995
+ return (decoded,)
996
+
997
+ return DecoderOutput(sample=decoded)
998
+
999
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1000
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1001
+ for y in range(blend_extent):
1002
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1003
+ y / blend_extent
1004
+ )
1005
+ return b
1006
+
1007
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1008
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1009
+ for x in range(blend_extent):
1010
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1011
+ x / blend_extent
1012
+ )
1013
+ return b
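A toy sketch (editorial) of what `blend_h` above does (`blend_v` is the same along the height axis): across the overlap, the previous tile's values fade out linearly while the current tile's values fade in, which hides seams between adjacent tiles. The tensor sizes are illustrative.

```python
import torch

blend_extent = 4
a = torch.ones(1, 1, 1, 1, 8)    # trailing columns of the tile to the left
b = torch.zeros(1, 1, 1, 1, 8)   # leading columns of the current tile

for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
        x / blend_extent
    )

# b[0, 0, 0, 0, :4] is now [1.00, 0.75, 0.50, 0.25]: a linear cross-fade from `a` into `b`
```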
1014
+
1015
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1016
+ r"""Encode a batch of images using a tiled encoder.
1017
+
1018
+ Args:
1019
+ x (`torch.Tensor`): Input batch of videos.
1020
+
1021
+ Returns:
1022
+ `torch.Tensor`:
1023
+ The latent representation of the encoded videos.
1024
+ """
1025
+ batch_size, num_channels, num_frames, height, width = x.shape
1026
+ latent_height = height // self.spatial_compression_ratio
1027
+ latent_width = width // self.spatial_compression_ratio
1028
+
1029
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1030
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1031
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1032
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1033
+
1034
+ blend_height = tile_latent_min_height - tile_latent_stride_height
1035
+ blend_width = tile_latent_min_width - tile_latent_stride_width
1036
+
1037
+ # Split x into overlapping tiles and encode them separately.
1038
+ # The tiles have an overlap to avoid seams between tiles.
1039
+ rows = []
1040
+ for i in range(0, height, self.tile_sample_stride_height):
1041
+ row = []
1042
+ for j in range(0, width, self.tile_sample_stride_width):
1043
+ if self.use_framewise_encoding:
1044
+ raise NotImplementedError(
1045
+ "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
1046
+ "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
1047
+ )
1048
+ else:
1049
+ time, _ = self.encoder(
1050
+ x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
1051
+ )
1052
+
1053
+ row.append(time)
1054
+ rows.append(row)
1055
+
1056
+ result_rows = []
1057
+ for i, row in enumerate(rows):
1058
+ result_row = []
1059
+ for j, tile in enumerate(row):
1060
+ # blend the above tile and the left tile
1061
+ # to the current tile and add the current tile to the result row
1062
+ if i > 0:
1063
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1064
+ if j > 0:
1065
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1066
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
1067
+ result_rows.append(torch.cat(result_row, dim=4))
1068
+
1069
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1070
+ return enc
1071
+
1072
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1073
+ r"""
1074
+ Decode a batch of images using a tiled decoder.
1075
+
1076
+ Args:
1077
+ z (`torch.Tensor`): Input batch of latent vectors.
1078
+ return_dict (`bool`, *optional*, defaults to `True`):
1079
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1080
+
1081
+ Returns:
1082
+ [`~models.vae.DecoderOutput`] or `tuple`:
1083
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1084
+ returned.
1085
+ """
1086
+
1087
+ batch_size, num_channels, num_frames, height, width = z.shape
1088
+ sample_height = height * self.spatial_compression_ratio
1089
+ sample_width = width * self.spatial_compression_ratio
1090
+
1091
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1092
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1093
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1094
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1095
+
1096
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1097
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1098
+
1099
+ # Split z into overlapping tiles and decode them separately.
1100
+ # The tiles have an overlap to avoid seams between tiles.
1101
+ rows = []
1102
+ for i in range(0, height, tile_latent_stride_height):
1103
+ row = []
1104
+ for j in range(0, width, tile_latent_stride_width):
1105
+ if self.use_framewise_decoding:
1106
+ time = []
1107
+ conv_cache = None
1108
+
1109
+ for k in range(0, num_frames, self.num_latent_frames_batch_size):
1110
+ tile = z[
1111
+ :,
1112
+ :,
1113
+ k : k + self.num_latent_frames_batch_size,
1114
+ i : i + tile_latent_min_height,
1115
+ j : j + tile_latent_min_width,
1116
+ ]
1117
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1118
+ time.append(tile)
1119
+
1120
+ time = torch.cat(time, dim=2)
1121
+ else:
1122
+ time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
1123
+
1124
+ if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio:
1125
+ time = time[:, :, self.temporal_compression_ratio - 1 :]
1126
+
1127
+ row.append(time)
1128
+ rows.append(row)
1129
+
1130
+ result_rows = []
1131
+ for i, row in enumerate(rows):
1132
+ result_row = []
1133
+ for j, tile in enumerate(row):
1134
+ # blend the above tile and the left tile
1135
+ # to the current tile and add the current tile to the result row
1136
+ if i > 0:
1137
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1138
+ if j > 0:
1139
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1140
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1141
+ result_rows.append(torch.cat(result_row, dim=4))
1142
+
1143
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1144
+
1145
+ if not return_dict:
1146
+ return (dec,)
1147
+
1148
+ return DecoderOutput(sample=dec)
1149
+
1150
+ def forward(
1151
+ self,
1152
+ sample: torch.Tensor,
1153
+ sample_posterior: bool = False,
1154
+ return_dict: bool = True,
1155
+ generator: Optional[torch.Generator] = None,
1156
+ ) -> Union[DecoderOutput, torch.Tensor]:
1157
+ x = sample
1158
+ posterior = self.encode(x).latent_dist
1159
+ if sample_posterior:
1160
+ z = posterior.sample(generator=generator)
1161
+ else:
1162
+ z = posterior.mode()
1163
+ dec = self.decode(z)
1164
+ if not return_dict:
1165
+ return (dec,)
1166
+ return dec
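A hedged usage sketch (editorial, not part of the diff) showing how this VAE is typically driven with the memory savers defined above. The checkpoint id, dtype, and latent shape are assumptions rather than anything stated in this file.

```python
import torch
from diffusers import AutoencoderKLMochi  # assumes the public diffusers package layout

vae = AutoencoderKLMochi.from_pretrained(
    "genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float16  # assumed repository id
).to("cuda")
vae.enable_tiling()   # spatial tiles, blended with blend_v/blend_h above
vae.enable_slicing()  # decode one latent video at a time when the batch is larger than 1

# 12 latent channels; the default config compresses 8x spatially and 6x temporally
latents = torch.randn(1, 12, 7, 60, 106, dtype=torch.float16, device="cuda")
with torch.no_grad():
    video = vae.decode(latents).sample
# video: (1, 3, 37, 480, 848), since drop_last_temporal_frames gives (7 - 1) * 6 + 1 output frames
```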
icedit/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py ADDED
@@ -0,0 +1,394 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import itertools
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ...utils import is_torch_version
22
+ from ...utils.accelerate_utils import apply_forward_hook
23
+ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
24
+ from ..modeling_outputs import AutoencoderKLOutput
25
+ from ..modeling_utils import ModelMixin
26
+ from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
27
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
28
+
29
+
30
+ class TemporalDecoder(nn.Module):
31
+ def __init__(
32
+ self,
33
+ in_channels: int = 4,
34
+ out_channels: int = 3,
35
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
36
+ layers_per_block: int = 2,
37
+ ):
38
+ super().__init__()
39
+ self.layers_per_block = layers_per_block
40
+
41
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
42
+ self.mid_block = MidBlockTemporalDecoder(
43
+ num_layers=self.layers_per_block,
44
+ in_channels=block_out_channels[-1],
45
+ out_channels=block_out_channels[-1],
46
+ attention_head_dim=block_out_channels[-1],
47
+ )
48
+
49
+ # up
50
+ self.up_blocks = nn.ModuleList([])
51
+ reversed_block_out_channels = list(reversed(block_out_channels))
52
+ output_channel = reversed_block_out_channels[0]
53
+ for i in range(len(block_out_channels)):
54
+ prev_output_channel = output_channel
55
+ output_channel = reversed_block_out_channels[i]
56
+
57
+ is_final_block = i == len(block_out_channels) - 1
58
+ up_block = UpBlockTemporalDecoder(
59
+ num_layers=self.layers_per_block + 1,
60
+ in_channels=prev_output_channel,
61
+ out_channels=output_channel,
62
+ add_upsample=not is_final_block,
63
+ )
64
+ self.up_blocks.append(up_block)
65
+ prev_output_channel = output_channel
66
+
67
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
68
+
69
+ self.conv_act = nn.SiLU()
70
+ self.conv_out = torch.nn.Conv2d(
71
+ in_channels=block_out_channels[0],
72
+ out_channels=out_channels,
73
+ kernel_size=3,
74
+ padding=1,
75
+ )
76
+
77
+ conv_out_kernel_size = (3, 1, 1)
78
+ padding = [int(k // 2) for k in conv_out_kernel_size]
79
+ self.time_conv_out = torch.nn.Conv3d(
80
+ in_channels=out_channels,
81
+ out_channels=out_channels,
82
+ kernel_size=conv_out_kernel_size,
83
+ padding=padding,
84
+ )
85
+
86
+ self.gradient_checkpointing = False
87
+
88
+ def forward(
89
+ self,
90
+ sample: torch.Tensor,
91
+ image_only_indicator: torch.Tensor,
92
+ num_frames: int = 1,
93
+ ) -> torch.Tensor:
94
+ r"""The forward method of the `Decoder` class."""
95
+
96
+ sample = self.conv_in(sample)
97
+
98
+ upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
99
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
100
+
101
+ def create_custom_forward(module):
102
+ def custom_forward(*inputs):
103
+ return module(*inputs)
104
+
105
+ return custom_forward
106
+
107
+ if is_torch_version(">=", "1.11.0"):
108
+ # middle
109
+ sample = torch.utils.checkpoint.checkpoint(
110
+ create_custom_forward(self.mid_block),
111
+ sample,
112
+ image_only_indicator,
113
+ use_reentrant=False,
114
+ )
115
+ sample = sample.to(upscale_dtype)
116
+
117
+ # up
118
+ for up_block in self.up_blocks:
119
+ sample = torch.utils.checkpoint.checkpoint(
120
+ create_custom_forward(up_block),
121
+ sample,
122
+ image_only_indicator,
123
+ use_reentrant=False,
124
+ )
125
+ else:
126
+ # middle
127
+ sample = torch.utils.checkpoint.checkpoint(
128
+ create_custom_forward(self.mid_block),
129
+ sample,
130
+ image_only_indicator,
131
+ )
132
+ sample = sample.to(upscale_dtype)
133
+
134
+ # up
135
+ for up_block in self.up_blocks:
136
+ sample = torch.utils.checkpoint.checkpoint(
137
+ create_custom_forward(up_block),
138
+ sample,
139
+ image_only_indicator,
140
+ )
141
+ else:
142
+ # middle
143
+ sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
144
+ sample = sample.to(upscale_dtype)
145
+
146
+ # up
147
+ for up_block in self.up_blocks:
148
+ sample = up_block(sample, image_only_indicator=image_only_indicator)
149
+
150
+ # post-process
151
+ sample = self.conv_norm_out(sample)
152
+ sample = self.conv_act(sample)
153
+ sample = self.conv_out(sample)
154
+
155
+ batch_frames, channels, height, width = sample.shape
156
+ batch_size = batch_frames // num_frames
157
+ sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
158
+ sample = self.time_conv_out(sample)
159
+
160
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
161
+
162
+ return sample
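A small sketch (editorial) of the reshaping around `time_conv_out` above: the 2-D decoder works on `(batch * frames, C, H, W)` while the temporal convolution needs `(batch, C, frames, H, W)`, so the frame axis is folded out of the batch dimension and back in afterwards. Shapes are illustrative.

```python
import torch

batch_size, num_frames, channels, height, width = 2, 5, 3, 16, 16
sample = torch.randn(batch_size * num_frames, channels, height, width)

x = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
# a Conv3d with kernel (3, 1, 1) and padding (1, 0, 0) preserves this (B, C, F, H, W) shape
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
assert x.shape == sample.shape
```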
163
+
164
+
165
+ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
166
+ r"""
167
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
168
+
169
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
170
+ for all models (such as downloading or saving).
171
+
172
+ Parameters:
173
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
174
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
175
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
176
+ Tuple of downsample block types.
177
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
178
+ Tuple of block output channels.
179
+ layers_per_block (`int`, *optional*, defaults to 1): Number of layers per block.
180
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
181
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
182
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
183
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
184
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
185
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
186
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
187
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
188
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
189
+ force_upcast (`bool`, *optional*, defaults to `True`):
190
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
191
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
192
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
193
+ """
194
+
195
+ _supports_gradient_checkpointing = True
196
+
197
+ @register_to_config
198
+ def __init__(
199
+ self,
200
+ in_channels: int = 3,
201
+ out_channels: int = 3,
202
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
203
+ block_out_channels: Tuple[int] = (64,),
204
+ layers_per_block: int = 1,
205
+ latent_channels: int = 4,
206
+ sample_size: int = 32,
207
+ scaling_factor: float = 0.18215,
208
+ force_upcast: bool = True,
209
+ ):
210
+ super().__init__()
211
+
212
+ # pass init params to Encoder
213
+ self.encoder = Encoder(
214
+ in_channels=in_channels,
215
+ out_channels=latent_channels,
216
+ down_block_types=down_block_types,
217
+ block_out_channels=block_out_channels,
218
+ layers_per_block=layers_per_block,
219
+ double_z=True,
220
+ )
221
+
222
+ # pass init params to Decoder
223
+ self.decoder = TemporalDecoder(
224
+ in_channels=latent_channels,
225
+ out_channels=out_channels,
226
+ block_out_channels=block_out_channels,
227
+ layers_per_block=layers_per_block,
228
+ )
229
+
230
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
231
+
232
+ def _set_gradient_checkpointing(self, module, value=False):
233
+ if isinstance(module, (Encoder, TemporalDecoder)):
234
+ module.gradient_checkpointing = value
235
+
236
+ @property
237
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
238
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
239
+ r"""
240
+ Returns:
241
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
242
+ indexed by its weight name.
243
+ """
244
+ # set recursively
245
+ processors = {}
246
+
247
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
248
+ if hasattr(module, "get_processor"):
249
+ processors[f"{name}.processor"] = module.get_processor()
250
+
251
+ for sub_name, child in module.named_children():
252
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
253
+
254
+ return processors
255
+
256
+ for name, module in self.named_children():
257
+ fn_recursive_add_processors(name, module, processors)
258
+
259
+ return processors
260
+
261
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
262
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
263
+ r"""
264
+ Sets the attention processor to use to compute attention.
265
+
266
+ Parameters:
267
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
268
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
269
+ for **all** `Attention` layers.
270
+
271
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
272
+ processor. This is strongly recommended when setting trainable attention processors.
273
+
274
+ """
275
+ count = len(self.attn_processors.keys())
276
+
277
+ if isinstance(processor, dict) and len(processor) != count:
278
+ raise ValueError(
279
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
280
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
281
+ )
282
+
283
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
284
+ if hasattr(module, "set_processor"):
285
+ if not isinstance(processor, dict):
286
+ module.set_processor(processor)
287
+ else:
288
+ module.set_processor(processor.pop(f"{name}.processor"))
289
+
290
+ for sub_name, child in module.named_children():
291
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
292
+
293
+ for name, module in self.named_children():
294
+ fn_recursive_attn_processor(name, module, processor)
295
+
296
+ def set_default_attn_processor(self):
297
+ """
298
+ Disables custom attention processors and sets the default attention implementation.
299
+ """
300
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
301
+ processor = AttnProcessor()
302
+ else:
303
+ raise ValueError(
304
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
305
+ )
306
+
307
+ self.set_attn_processor(processor)
308
+
309
+ @apply_forward_hook
310
+ def encode(
311
+ self, x: torch.Tensor, return_dict: bool = True
312
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
313
+ """
314
+ Encode a batch of images into latents.
315
+
316
+ Args:
317
+ x (`torch.Tensor`): Input batch of images.
318
+ return_dict (`bool`, *optional*, defaults to `True`):
319
+ Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
320
+ tuple.
321
+
322
+ Returns:
323
+ The latent representations of the encoded images. If `return_dict` is True, a
324
+ [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
325
+ returned.
326
+ """
327
+ h = self.encoder(x)
328
+ moments = self.quant_conv(h)
329
+ posterior = DiagonalGaussianDistribution(moments)
330
+
331
+ if not return_dict:
332
+ return (posterior,)
333
+
334
+ return AutoencoderKLOutput(latent_dist=posterior)
335
+
336
+ @apply_forward_hook
337
+ def decode(
338
+ self,
339
+ z: torch.Tensor,
340
+ num_frames: int,
341
+ return_dict: bool = True,
342
+ ) -> Union[DecoderOutput, torch.Tensor]:
343
+ """
344
+ Decode a batch of images.
345
+
346
+ Args:
347
+ z (`torch.Tensor`): Input batch of latent vectors.
348
+ return_dict (`bool`, *optional*, defaults to `True`):
349
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
350
+
351
+ Returns:
352
+ [`~models.vae.DecoderOutput`] or `tuple`:
353
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
354
+ returned.
355
+
356
+ """
357
+ batch_size = z.shape[0] // num_frames
358
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
359
+ decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
360
+
361
+ if not return_dict:
362
+ return (decoded,)
363
+
364
+ return DecoderOutput(sample=decoded)
365
+
366
+ def forward(
367
+ self,
368
+ sample: torch.Tensor,
369
+ sample_posterior: bool = False,
370
+ return_dict: bool = True,
371
+ generator: Optional[torch.Generator] = None,
372
+ num_frames: int = 1,
373
+ ) -> Union[DecoderOutput, torch.Tensor]:
374
+ r"""
375
+ Args:
376
+ sample (`torch.Tensor`): Input sample.
377
+ sample_posterior (`bool`, *optional*, defaults to `False`):
378
+ Whether to sample from the posterior.
379
+ return_dict (`bool`, *optional*, defaults to `True`):
380
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
381
+ """
382
+ x = sample
383
+ posterior = self.encode(x).latent_dist
384
+ if sample_posterior:
385
+ z = posterior.sample(generator=generator)
386
+ else:
387
+ z = posterior.mode()
388
+
389
+ dec = self.decode(z, num_frames=num_frames).sample
390
+
391
+ if not return_dict:
392
+ return (dec,)
393
+
394
+ return DecoderOutput(sample=dec)
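A hedged usage sketch (editorial): this temporal-decoder VAE is the one used by Stable Video Diffusion, so the checkpoint id, dtype, and latent shape below are assumptions about that pipeline rather than facts from this file.

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

num_frames = 14
latents = torch.randn(num_frames, 4, 72, 128, dtype=torch.float16, device="cuda")  # frames folded into batch
with torch.no_grad():
    frames = vae.decode(latents / vae.config.scaling_factor, num_frames=num_frames).sample
# frames: (14, 3, 576, 1024); decode() regroups the batch into num_frames-long chunks of one video
```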
icedit/diffusers/models/autoencoders/autoencoder_oobleck.py ADDED
@@ -0,0 +1,464 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.utils import weight_norm
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...utils import BaseOutput
25
+ from ...utils.accelerate_utils import apply_forward_hook
26
+ from ...utils.torch_utils import randn_tensor
27
+ from ..modeling_utils import ModelMixin
28
+
29
+
30
+ class Snake1d(nn.Module):
31
+ """
32
+ A 1-dimensional Snake activation function module.
33
+ """
34
+
35
+ def __init__(self, hidden_dim, logscale=True):
36
+ super().__init__()
37
+ self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
38
+ self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
39
+
40
+ self.alpha.requires_grad = True
41
+ self.beta.requires_grad = True
42
+ self.logscale = logscale
43
+
44
+ def forward(self, hidden_states):
45
+ shape = hidden_states.shape
46
+
47
+ alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
48
+ beta = self.beta if not self.logscale else torch.exp(self.beta)
49
+
50
+ hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
51
+ hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
52
+ hidden_states = hidden_states.reshape(shape)
53
+ return hidden_states
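A sketch (editorial) of the Snake activation computed above, `snake(x) = x + sin(alpha * x)^2 / (beta + 1e-9)`, where `alpha` and `beta` are per-channel parameters stored in log-space when `logscale=True`, so freshly initialized zeros correspond to `alpha = beta = 1`. Shapes are illustrative.

```python
import torch

x = torch.randn(2, 4, 16)                 # (batch, channels, time)
alpha = torch.exp(torch.zeros(1, 4, 1))   # exp(0) = 1 for newly initialized parameters
beta = torch.exp(torch.zeros(1, 4, 1))

y = x + (beta + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
```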
54
+
55
+
56
+ class OobleckResidualUnit(nn.Module):
57
+ """
58
+ A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
59
+ """
60
+
61
+ def __init__(self, dimension: int = 16, dilation: int = 1):
62
+ super().__init__()
63
+ pad = ((7 - 1) * dilation) // 2
64
+
65
+ self.snake1 = Snake1d(dimension)
66
+ self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
67
+ self.snake2 = Snake1d(dimension)
68
+ self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
69
+
70
+ def forward(self, hidden_state):
71
+ """
72
+ Forward pass through the residual unit.
73
+
74
+ Args:
75
+ hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
76
+ Input tensor.
77
+
78
+ Returns:
79
+ output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`)
80
+ Input tensor after passing through the residual unit.
81
+ """
82
+ output_tensor = hidden_state
83
+ output_tensor = self.conv1(self.snake1(output_tensor))
84
+ output_tensor = self.conv2(self.snake2(output_tensor))
85
+
86
+ padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
87
+ if padding > 0:
88
+ hidden_state = hidden_state[..., padding:-padding]
89
+ output_tensor = hidden_state + output_tensor
90
+ return output_tensor
91
+
92
+
93
+ class OobleckEncoderBlock(nn.Module):
94
+ """Encoder block used in Oobleck encoder."""
95
+
96
+ def __init__(self, input_dim, output_dim, stride: int = 1):
97
+ super().__init__()
98
+
99
+ self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
100
+ self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
101
+ self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
102
+ self.snake1 = Snake1d(input_dim)
103
+ self.conv1 = weight_norm(
104
+ nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
105
+ )
106
+
107
+ def forward(self, hidden_state):
108
+ hidden_state = self.res_unit1(hidden_state)
109
+ hidden_state = self.res_unit2(hidden_state)
110
+ hidden_state = self.snake1(self.res_unit3(hidden_state))
111
+ hidden_state = self.conv1(hidden_state)
112
+
113
+ return hidden_state
114
+
115
+
116
+ class OobleckDecoderBlock(nn.Module):
117
+ """Decoder block used in Oobleck decoder."""
118
+
119
+ def __init__(self, input_dim, output_dim, stride: int = 1):
120
+ super().__init__()
121
+
122
+ self.snake1 = Snake1d(input_dim)
123
+ self.conv_t1 = weight_norm(
124
+ nn.ConvTranspose1d(
125
+ input_dim,
126
+ output_dim,
127
+ kernel_size=2 * stride,
128
+ stride=stride,
129
+ padding=math.ceil(stride / 2),
130
+ )
131
+ )
132
+ self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
133
+ self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
134
+ self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
135
+
136
+ def forward(self, hidden_state):
137
+ hidden_state = self.snake1(hidden_state)
138
+ hidden_state = self.conv_t1(hidden_state)
139
+ hidden_state = self.res_unit1(hidden_state)
140
+ hidden_state = self.res_unit2(hidden_state)
141
+ hidden_state = self.res_unit3(hidden_state)
142
+
143
+ return hidden_state
144
+
145
+
146
+ class OobleckDiagonalGaussianDistribution(object):
147
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
148
+ self.parameters = parameters
149
+ self.mean, self.scale = parameters.chunk(2, dim=1)
150
+ self.std = nn.functional.softplus(self.scale) + 1e-4
151
+ self.var = self.std * self.std
152
+ self.logvar = torch.log(self.var)
153
+ self.deterministic = deterministic
154
+
155
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
156
+ # make sure sample is on the same device as the parameters and has same dtype
157
+ sample = randn_tensor(
158
+ self.mean.shape,
159
+ generator=generator,
160
+ device=self.parameters.device,
161
+ dtype=self.parameters.dtype,
162
+ )
163
+ x = self.mean + self.std * sample
164
+ return x
165
+
166
+ def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
167
+ if self.deterministic:
168
+ return torch.Tensor([0.0])
169
+ else:
170
+ if other is None:
171
+ return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
172
+ else:
173
+ normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
174
+ var_ratio = self.var / other.var
175
+ logvar_diff = self.logvar - other.logvar
176
+
177
+ kl = normalized_diff + var_ratio + logvar_diff - 1
178
+
179
+ kl = kl.sum(1).mean()
180
+ return kl
181
+
182
+ def mode(self) -> torch.Tensor:
183
+ return self.mean
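A sketch (editorial) of the parameterization above: the encoder output is split channel-wise into mean and scale, the standard deviation is `softplus(scale) + 1e-4`, sampling uses the reparameterization trick, and the KL against a standard normal matches `kl(other=None)`. Shapes are illustrative.

```python
import torch
import torch.nn.functional as F

params = torch.randn(1, 2 * 64, 100)       # (batch, 2 * latent_channels, time)
mean, scale = params.chunk(2, dim=1)
std = F.softplus(scale) + 1e-4
var = std * std

sample = mean + std * torch.randn_like(mean)                    # what .sample() returns
kl = (mean * mean + var - torch.log(var) - 1.0).sum(1).mean()   # what .kl() returns when other is None
```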
184
+
185
+
186
+ @dataclass
187
+ class AutoencoderOobleckOutput(BaseOutput):
188
+ """
189
+ Output of AutoencoderOobleck encoding method.
190
+
191
+ Args:
192
+ latent_dist (`OobleckDiagonalGaussianDistribution`):
193
+ Encoded outputs of `Encoder` represented as the mean and standard deviation of
194
+ `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
195
+ from the distribution.
196
+ """
197
+
198
+ latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821
199
+
200
+
201
+ @dataclass
202
+ class OobleckDecoderOutput(BaseOutput):
203
+ r"""
204
+ Output of decoding method.
205
+
206
+ Args:
207
+ sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
208
+ The decoded output sample from the last layer of the model.
209
+ """
210
+
211
+ sample: torch.Tensor
212
+
213
+
214
+ class OobleckEncoder(nn.Module):
215
+ """Oobleck Encoder"""
216
+
217
+ def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
218
+ super().__init__()
219
+
220
+ strides = downsampling_ratios
221
+ channel_multiples = [1] + channel_multiples
222
+
223
+ # Create first convolution
224
+ self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
225
+
226
+ self.block = []
227
+ # Create EncoderBlocks that double channels as they downsample by `stride`
228
+ for stride_index, stride in enumerate(strides):
229
+ self.block += [
230
+ OobleckEncoderBlock(
231
+ input_dim=encoder_hidden_size * channel_multiples[stride_index],
232
+ output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
233
+ stride=stride,
234
+ )
235
+ ]
236
+
237
+ self.block = nn.ModuleList(self.block)
238
+ d_model = encoder_hidden_size * channel_multiples[-1]
239
+ self.snake1 = Snake1d(d_model)
240
+ self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
241
+
242
+ def forward(self, hidden_state):
243
+ hidden_state = self.conv1(hidden_state)
244
+
245
+ for module in self.block:
246
+ hidden_state = module(hidden_state)
247
+
248
+ hidden_state = self.snake1(hidden_state)
249
+ hidden_state = self.conv2(hidden_state)
250
+
251
+ return hidden_state
252
+
253
+
254
+ class OobleckDecoder(nn.Module):
255
+ """Oobleck Decoder"""
256
+
257
+ def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
258
+ super().__init__()
259
+
260
+ strides = upsampling_ratios
261
+ channel_multiples = [1] + channel_multiples
262
+
263
+ # Add first conv layer
264
+ self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
265
+
266
+ # Add upsampling + MRF blocks
267
+ block = []
268
+ for stride_index, stride in enumerate(strides):
269
+ block += [
270
+ OobleckDecoderBlock(
271
+ input_dim=channels * channel_multiples[len(strides) - stride_index],
272
+ output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
273
+ stride=stride,
274
+ )
275
+ ]
276
+
277
+ self.block = nn.ModuleList(block)
278
+ output_dim = channels
279
+ self.snake1 = Snake1d(output_dim)
280
+ self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
281
+
282
+ def forward(self, hidden_state):
283
+ hidden_state = self.conv1(hidden_state)
284
+
285
+ for layer in self.block:
286
+ hidden_state = layer(hidden_state)
287
+
288
+ hidden_state = self.snake1(hidden_state)
289
+ hidden_state = self.conv2(hidden_state)
290
+
291
+ return hidden_state
292
+
293
+
294
+ class AutoencoderOobleck(ModelMixin, ConfigMixin):
295
+ r"""
296
+ An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
297
+ introduced in Stable Audio.
298
+
299
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
300
+ for all models (such as downloading or saving).
301
+
302
+ Parameters:
303
+ encoder_hidden_size (`int`, *optional*, defaults to 128):
304
+ Intermediate representation dimension for the encoder.
305
+ downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
306
+ Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
307
+ channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
308
+ Multiples used to determine the hidden sizes of the hidden layers.
309
+ decoder_channels (`int`, *optional*, defaults to 128):
310
+ Intermediate representation dimension for the decoder.
311
+ decoder_input_channels (`int`, *optional*, defaults to 64):
312
+ Input dimension for the decoder. Corresponds to the latent dimension.
313
+ audio_channels (`int`, *optional*, defaults to 2):
314
+ Number of channels in the audio data. Either 1 for mono or 2 for stereo.
315
+ sampling_rate (`int`, *optional*, defaults to 44100):
316
+ The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
317
+ """
318
+
319
+ _supports_gradient_checkpointing = False
320
+
321
+ @register_to_config
322
+ def __init__(
323
+ self,
324
+ encoder_hidden_size=128,
325
+ downsampling_ratios=[2, 4, 4, 8, 8],
326
+ channel_multiples=[1, 2, 4, 8, 16],
327
+ decoder_channels=128,
328
+ decoder_input_channels=64,
329
+ audio_channels=2,
330
+ sampling_rate=44100,
331
+ ):
332
+ super().__init__()
333
+
334
+ self.encoder_hidden_size = encoder_hidden_size
335
+ self.downsampling_ratios = downsampling_ratios
336
+ self.decoder_channels = decoder_channels
337
+ self.upsampling_ratios = downsampling_ratios[::-1]
338
+ self.hop_length = int(np.prod(downsampling_ratios))
339
+ self.sampling_rate = sampling_rate
340
+
341
+ self.encoder = OobleckEncoder(
342
+ encoder_hidden_size=encoder_hidden_size,
343
+ audio_channels=audio_channels,
344
+ downsampling_ratios=downsampling_ratios,
345
+ channel_multiples=channel_multiples,
346
+ )
347
+
348
+ self.decoder = OobleckDecoder(
349
+ channels=decoder_channels,
350
+ input_channels=decoder_input_channels,
351
+ audio_channels=audio_channels,
352
+ upsampling_ratios=self.upsampling_ratios,
353
+ channel_multiples=channel_multiples,
354
+ )
355
+
356
+ self.use_slicing = False
357
+
358
+ def enable_slicing(self):
359
+ r"""
360
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
361
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
362
+ """
363
+ self.use_slicing = True
364
+
365
+ def disable_slicing(self):
366
+ r"""
367
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
368
+ decoding in one step.
369
+ """
370
+ self.use_slicing = False
371
+
372
+ @apply_forward_hook
373
+ def encode(
374
+ self, x: torch.Tensor, return_dict: bool = True
375
+ ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
376
+ """
377
+ Encode a batch of audio waveforms into latents.
378
+
379
+ Args:
380
+ x (`torch.Tensor`): Input batch of audio waveforms.
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether to return a [`~models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput`] instead of a plain tuple.
383
+
384
+ Returns:
385
+ The latent representations of the encoded audio. If `return_dict` is True, a
386
+ [`~models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput`] is returned, otherwise a plain `tuple` is returned.
387
+ """
388
+ if self.use_slicing and x.shape[0] > 1:
389
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
390
+ h = torch.cat(encoded_slices)
391
+ else:
392
+ h = self.encoder(x)
393
+
394
+ posterior = OobleckDiagonalGaussianDistribution(h)
395
+
396
+ if not return_dict:
397
+ return (posterior,)
398
+
399
+ return AutoencoderOobleckOutput(latent_dist=posterior)
400
+
401
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
402
+ dec = self.decoder(z)
403
+
404
+ if not return_dict:
405
+ return (dec,)
406
+
407
+ return OobleckDecoderOutput(sample=dec)
408
+
409
+ @apply_forward_hook
410
+ def decode(
411
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
412
+ ) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
413
+ """
414
+ Decode a batch of latents into audio waveforms.
415
+
416
+ Args:
417
+ z (`torch.Tensor`): Input batch of latent vectors.
418
+ return_dict (`bool`, *optional*, defaults to `True`):
419
+ Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
420
+
421
+ Returns:
422
+ [`~models.vae.OobleckDecoderOutput`] or `tuple`:
423
+ If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
424
+ is returned.
425
+
426
+ """
427
+ if self.use_slicing and z.shape[0] > 1:
428
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
429
+ decoded = torch.cat(decoded_slices)
430
+ else:
431
+ decoded = self._decode(z).sample
432
+
433
+ if not return_dict:
434
+ return (decoded,)
435
+
436
+ return OobleckDecoderOutput(sample=decoded)
437
+
438
+ def forward(
439
+ self,
440
+ sample: torch.Tensor,
441
+ sample_posterior: bool = False,
442
+ return_dict: bool = True,
443
+ generator: Optional[torch.Generator] = None,
444
+ ) -> Union[OobleckDecoderOutput, torch.Tensor]:
445
+ r"""
446
+ Args:
447
+ sample (`torch.Tensor`): Input sample.
448
+ sample_posterior (`bool`, *optional*, defaults to `False`):
449
+ Whether to sample from the posterior.
450
+ return_dict (`bool`, *optional*, defaults to `True`):
451
+ Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
452
+ """
453
+ x = sample
454
+ posterior = self.encode(x).latent_dist
455
+ if sample_posterior:
456
+ z = posterior.sample(generator=generator)
457
+ else:
458
+ z = posterior.mode()
459
+ dec = self.decode(z).sample
460
+
461
+ if not return_dict:
462
+ return (dec,)
463
+
464
+ return OobleckDecoderOutput(sample=dec)
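A minimal usage sketch for the `AutoencoderOobleck` class above (the checkpoint id, subfolder, and waveform length are assumptions, not part of this diff). With the default config the hop length is `prod([2, 4, 4, 8, 8]) = 2048` samples, so a 44.1 kHz stereo waveform of length `L` maps to roughly `L / 2048` latent frames with 64 channels.

```py
# Hedged sketch: "stabilityai/stable-audio-open-1.0" and the "vae" subfolder are assumptions.
import torch
from diffusers import AutoencoderOobleck

vae = AutoencoderOobleck.from_pretrained(
    "stabilityai/stable-audio-open-1.0", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

waveform = torch.randn(1, 2, 2**17, dtype=torch.float16, device="cuda")  # ~3 s of stereo audio at 44.1 kHz
posterior = vae.encode(waveform).latent_dist   # OobleckDiagonalGaussianDistribution
latents = posterior.sample()                   # -> (1, 64, 64) latent frames for this input length
reconstruction = vae.decode(latents).sample    # waveform reconstruction, same shape as the input
```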
icedit/diffusers/models/autoencoders/autoencoder_tiny.py ADDED
@@ -0,0 +1,350 @@
1
+ # Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...utils import BaseOutput
23
+ from ...utils.accelerate_utils import apply_forward_hook
24
+ from ..modeling_utils import ModelMixin
25
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
26
+
27
+
28
+ @dataclass
29
+ class AutoencoderTinyOutput(BaseOutput):
30
+ """
31
+ Output of AutoencoderTiny encoding method.
32
+
33
+ Args:
34
+ latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
35
+
36
+ """
37
+
38
+ latents: torch.Tensor
39
+
40
+
41
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
44
+
45
+ [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
46
+
47
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
48
+ all models (such as downloading or saving).
49
+
50
+ Parameters:
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
54
+ Tuple of integers representing the number of output channels for each encoder block. The length of the
55
+ tuple should be equal to the number of encoder blocks.
56
+ decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
57
+ Tuple of integers representing the number of output channels for each decoder block. The length of the
58
+ tuple should be equal to the number of decoder blocks.
59
+ act_fn (`str`, *optional*, defaults to `"relu"`):
60
+ Activation function to be used throughout the model.
61
+ latent_channels (`int`, *optional*, defaults to 4):
62
+ Number of channels in the latent representation. The latent space acts as a compressed representation of
63
+ the input image.
64
+ upsampling_scaling_factor (`int`, *optional*, defaults to 2):
65
+ Scaling factor for upsampling in the decoder. It determines the size of the output image during the
66
+ upsampling process.
67
+ num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
68
+ Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
69
+ length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
70
+ number of encoder blocks.
71
+ num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
72
+ Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
73
+ length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
74
+ number of decoder blocks.
75
+ latent_magnitude (`float`, *optional*, defaults to 3.0):
76
+ Magnitude of the latent representation. This parameter scales the latent representation values to control
77
+ the extent of information preservation.
78
+ latent_shift (float, *optional*, defaults to 0.5):
79
+ Shift applied to the latent representation. This parameter controls the center of the latent space.
80
+ scaling_factor (`float`, *optional*, defaults to 1.0):
81
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
82
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
83
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
84
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
85
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
86
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
87
+ however, no such scaling factor was used, hence the value of 1.0 as the default.
88
+ force_upcast (`bool`, *optional*, default to `False`):
89
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
90
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
91
+ `force_upcast` can be set to `False` (see this fp16-friendly
92
+ [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
93
+ """
94
+
95
+ _supports_gradient_checkpointing = True
96
+
97
+ @register_to_config
98
+ def __init__(
99
+ self,
100
+ in_channels: int = 3,
101
+ out_channels: int = 3,
102
+ encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
103
+ decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
104
+ act_fn: str = "relu",
105
+ upsample_fn: str = "nearest",
106
+ latent_channels: int = 4,
107
+ upsampling_scaling_factor: int = 2,
108
+ num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
109
+ num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
110
+ latent_magnitude: int = 3,
111
+ latent_shift: float = 0.5,
112
+ force_upcast: bool = False,
113
+ scaling_factor: float = 1.0,
114
+ shift_factor: float = 0.0,
115
+ ):
116
+ super().__init__()
117
+
118
+ if len(encoder_block_out_channels) != len(num_encoder_blocks):
119
+ raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
120
+ if len(decoder_block_out_channels) != len(num_decoder_blocks):
121
+ raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
122
+
123
+ self.encoder = EncoderTiny(
124
+ in_channels=in_channels,
125
+ out_channels=latent_channels,
126
+ num_blocks=num_encoder_blocks,
127
+ block_out_channels=encoder_block_out_channels,
128
+ act_fn=act_fn,
129
+ )
130
+
131
+ self.decoder = DecoderTiny(
132
+ in_channels=latent_channels,
133
+ out_channels=out_channels,
134
+ num_blocks=num_decoder_blocks,
135
+ block_out_channels=decoder_block_out_channels,
136
+ upsampling_scaling_factor=upsampling_scaling_factor,
137
+ act_fn=act_fn,
138
+ upsample_fn=upsample_fn,
139
+ )
140
+
141
+ self.latent_magnitude = latent_magnitude
142
+ self.latent_shift = latent_shift
143
+ self.scaling_factor = scaling_factor
144
+
145
+ self.use_slicing = False
146
+ self.use_tiling = False
147
+
148
+ # only relevant if vae tiling is enabled
149
+ self.spatial_scale_factor = 2**out_channels
150
+ self.tile_overlap_factor = 0.125
151
+ self.tile_sample_min_size = 512
152
+ self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
153
+
154
+ self.register_to_config(block_out_channels=decoder_block_out_channels)
155
+ self.register_to_config(force_upcast=False)
156
+
157
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
158
+ if isinstance(module, (EncoderTiny, DecoderTiny)):
159
+ module.gradient_checkpointing = value
160
+
161
+ def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
162
+ """raw latents -> [0, 1]"""
163
+ return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
164
+
165
+ def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
166
+ """[0, 1] -> raw latents"""
167
+ return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
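As a worked example, `scale_latents` and `unscale_latents` are inverse affine maps: with the defaults `latent_magnitude=3` and `latent_shift=0.5`, a raw latent value of 1.5 maps to 1.5 / 6 + 0.5 = 0.75 and back to (0.75 - 0.5) * 6 = 1.5; only values outside the [-3, 3] range are lost to the clamp in `scale_latents`.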
168
+
169
+ def enable_slicing(self) -> None:
170
+ r"""
171
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
172
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
173
+ """
174
+ self.use_slicing = True
175
+
176
+ def disable_slicing(self) -> None:
177
+ r"""
178
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
179
+ decoding in one step.
180
+ """
181
+ self.use_slicing = False
182
+
183
+ def enable_tiling(self, use_tiling: bool = True) -> None:
184
+ r"""
185
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
186
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
187
+ processing larger images.
188
+ """
189
+ self.use_tiling = use_tiling
190
+
191
+ def disable_tiling(self) -> None:
192
+ r"""
193
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
194
+ decoding in one step.
195
+ """
196
+ self.enable_tiling(False)
197
+
198
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
199
+ r"""Encode a batch of images using a tiled encoder.
200
+
201
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
202
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
203
+ tiles overlap and are blended together to form a smooth output.
204
+
205
+ Args:
206
+ x (`torch.Tensor`): Input batch of images.
207
+
208
+ Returns:
209
+ `torch.Tensor`: Encoded batch of images.
210
+ """
211
+ # scale of encoder output relative to input
212
+ sf = self.spatial_scale_factor
213
+ tile_size = self.tile_sample_min_size
214
+
215
+ # number of pixels to blend and to traverse between tiles
216
+ blend_size = int(tile_size * self.tile_overlap_factor)
217
+ traverse_size = tile_size - blend_size
218
+
219
+ # tiles index (up/left)
220
+ ti = range(0, x.shape[-2], traverse_size)
221
+ tj = range(0, x.shape[-1], traverse_size)
222
+
223
+ # mask for blending
224
+ blend_masks = torch.stack(
225
+ torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
226
+ )
227
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
228
+
229
+ # output array
230
+ out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
231
+ for i in ti:
232
+ for j in tj:
233
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
234
+ # tile result
235
+ tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
236
+ tile = self.encoder(tile_in)
237
+ h, w = tile.shape[-2], tile.shape[-1]
238
+ # blend tile result into output
239
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
240
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
241
+ blend_mask = blend_mask_i * blend_mask_j
242
+ tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
243
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
244
+ return out
245
+
246
+ def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
247
+ r"""Encode a batch of images using a tiled encoder.
248
+
249
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
250
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
251
+ tiles overlap and are blended together to form a smooth output.
252
+
253
+ Args:
254
+ x (`torch.Tensor`): Input batch of latents.
255
+
256
+ Returns:
257
+ `torch.Tensor`: Decoded batch of images.
258
+ """
259
+ # scale of decoder output relative to input
260
+ sf = self.spatial_scale_factor
261
+ tile_size = self.tile_latent_min_size
262
+
263
+ # number of pixels to blend and to traverse between tiles
264
+ blend_size = int(tile_size * self.tile_overlap_factor)
265
+ traverse_size = tile_size - blend_size
266
+
267
+ # tiles index (up/left)
268
+ ti = range(0, x.shape[-2], traverse_size)
269
+ tj = range(0, x.shape[-1], traverse_size)
270
+
271
+ # mask for blending
272
+ blend_masks = torch.stack(
273
+ torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
274
+ )
275
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
276
+
277
+ # output array
278
+ out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
279
+ for i in ti:
280
+ for j in tj:
281
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
282
+ # tile result
283
+ tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
284
+ tile = self.decoder(tile_in)
285
+ h, w = tile.shape[-2], tile.shape[-1]
286
+ # blend tile result into output
287
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
288
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
289
+ blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
290
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
291
+ return out
292
+
293
+ @apply_forward_hook
294
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
295
+ if self.use_slicing and x.shape[0] > 1:
296
+ output = [
297
+ self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
298
+ ]
299
+ output = torch.cat(output)
300
+ else:
301
+ output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
302
+
303
+ if not return_dict:
304
+ return (output,)
305
+
306
+ return AutoencoderTinyOutput(latents=output)
307
+
308
+ @apply_forward_hook
309
+ def decode(
310
+ self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
311
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
312
+ if self.use_slicing and x.shape[0] > 1:
313
+ output = [
314
+ self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
315
+ ]
316
+ output = torch.cat(output)
317
+ else:
318
+ output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
319
+
320
+ if not return_dict:
321
+ return (output,)
322
+
323
+ return DecoderOutput(sample=output)
324
+
325
+ def forward(
326
+ self,
327
+ sample: torch.Tensor,
328
+ return_dict: bool = True,
329
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
330
+ r"""
331
+ Args:
332
+ sample (`torch.Tensor`): Input sample.
333
+ return_dict (`bool`, *optional*, defaults to `True`):
334
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
335
+ """
336
+ enc = self.encode(sample).latents
337
+
338
+ # scale latents to be in [0, 1], then quantize latents to a byte tensor,
339
+ # as if we were storing the latents in an RGBA uint8 image.
340
+ scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
341
+
342
+ # unquantize latents back into [0, 1], then unscale latents back to their original range,
343
+ # as if we were loading the latents from an RGBA uint8 image.
344
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
345
+
346
+ dec = self.decode(unscaled_enc).sample
347
+
348
+ if not return_dict:
349
+ return (dec,)
350
+ return DecoderOutput(sample=dec)
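A minimal usage sketch for `AutoencoderTiny` (the checkpoint id `madebyollin/taesd` and the 512x512 input are assumptions): encode an image to 4-channel latents at 1/8 resolution, optionally apply the same uint8 round trip that `forward` performs, and decode.

```py
# Hedged sketch: checkpoint id and input resolution are assumptions.
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")

image = torch.rand(1, 3, 512, 512, dtype=torch.float16, device="cuda")  # values in [0, 1]
latents = vae.encode(image).latents                                     # (1, 4, 64, 64)

# Same quantization round trip as `forward`: [0, 1] scaling -> uint8 -> back to raw latents.
quantized = vae.scale_latents(latents).mul_(255).round_().byte()
latents = vae.unscale_latents(quantized.to(image.dtype) / 255.0)

reconstruction = vae.decode(latents).sample                             # (1, 3, 512, 512)
```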
icedit/diffusers/models/autoencoders/consistency_decoder_vae.py ADDED
@@ -0,0 +1,460 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...schedulers import ConsistencyDecoderScheduler
23
+ from ...utils import BaseOutput
24
+ from ...utils.accelerate_utils import apply_forward_hook
25
+ from ...utils.torch_utils import randn_tensor
26
+ from ..attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from ..modeling_utils import ModelMixin
34
+ from ..unets.unet_2d import UNet2DModel
35
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
36
+
37
+
38
+ @dataclass
39
+ class ConsistencyDecoderVAEOutput(BaseOutput):
40
+ """
41
+ Output of encoding method.
42
+
43
+ Args:
44
+ latent_dist (`DiagonalGaussianDistribution`):
45
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
46
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
47
+ """
48
+
49
+ latent_dist: "DiagonalGaussianDistribution"
50
+
51
+
52
+ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
53
+ r"""
54
+ The consistency decoder used with DALL-E 3.
55
+
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
60
+
61
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
62
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
63
+ ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
+ ... ).to("cuda")
65
+
66
+ >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
67
+ >>> image
68
+ ```
69
+ """
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ scaling_factor: float = 0.18215,
75
+ latent_channels: int = 4,
76
+ sample_size: int = 32,
77
+ encoder_act_fn: str = "silu",
78
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
79
+ encoder_double_z: bool = True,
80
+ encoder_down_block_types: Tuple[str, ...] = (
81
+ "DownEncoderBlock2D",
82
+ "DownEncoderBlock2D",
83
+ "DownEncoderBlock2D",
84
+ "DownEncoderBlock2D",
85
+ ),
86
+ encoder_in_channels: int = 3,
87
+ encoder_layers_per_block: int = 2,
88
+ encoder_norm_num_groups: int = 32,
89
+ encoder_out_channels: int = 4,
90
+ decoder_add_attention: bool = False,
91
+ decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
92
+ decoder_down_block_types: Tuple[str, ...] = (
93
+ "ResnetDownsampleBlock2D",
94
+ "ResnetDownsampleBlock2D",
95
+ "ResnetDownsampleBlock2D",
96
+ "ResnetDownsampleBlock2D",
97
+ ),
98
+ decoder_downsample_padding: int = 1,
99
+ decoder_in_channels: int = 7,
100
+ decoder_layers_per_block: int = 3,
101
+ decoder_norm_eps: float = 1e-05,
102
+ decoder_norm_num_groups: int = 32,
103
+ decoder_num_train_timesteps: int = 1024,
104
+ decoder_out_channels: int = 6,
105
+ decoder_resnet_time_scale_shift: str = "scale_shift",
106
+ decoder_time_embedding_type: str = "learned",
107
+ decoder_up_block_types: Tuple[str, ...] = (
108
+ "ResnetUpsampleBlock2D",
109
+ "ResnetUpsampleBlock2D",
110
+ "ResnetUpsampleBlock2D",
111
+ "ResnetUpsampleBlock2D",
112
+ ),
113
+ ):
114
+ super().__init__()
115
+ self.encoder = Encoder(
116
+ act_fn=encoder_act_fn,
117
+ block_out_channels=encoder_block_out_channels,
118
+ double_z=encoder_double_z,
119
+ down_block_types=encoder_down_block_types,
120
+ in_channels=encoder_in_channels,
121
+ layers_per_block=encoder_layers_per_block,
122
+ norm_num_groups=encoder_norm_num_groups,
123
+ out_channels=encoder_out_channels,
124
+ )
125
+
126
+ self.decoder_unet = UNet2DModel(
127
+ add_attention=decoder_add_attention,
128
+ block_out_channels=decoder_block_out_channels,
129
+ down_block_types=decoder_down_block_types,
130
+ downsample_padding=decoder_downsample_padding,
131
+ in_channels=decoder_in_channels,
132
+ layers_per_block=decoder_layers_per_block,
133
+ norm_eps=decoder_norm_eps,
134
+ norm_num_groups=decoder_norm_num_groups,
135
+ num_train_timesteps=decoder_num_train_timesteps,
136
+ out_channels=decoder_out_channels,
137
+ resnet_time_scale_shift=decoder_resnet_time_scale_shift,
138
+ time_embedding_type=decoder_time_embedding_type,
139
+ up_block_types=decoder_up_block_types,
140
+ )
141
+ self.decoder_scheduler = ConsistencyDecoderScheduler()
142
+ self.register_to_config(block_out_channels=encoder_block_out_channels)
143
+ self.register_to_config(force_upcast=False)
144
+ self.register_buffer(
145
+ "means",
146
+ torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
147
+ persistent=False,
148
+ )
149
+ self.register_buffer(
150
+ "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
151
+ )
152
+
153
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
154
+
155
+ self.use_slicing = False
156
+ self.use_tiling = False
157
+
158
+ # only relevant if vae tiling is enabled
159
+ self.tile_sample_min_size = self.config.sample_size
160
+ sample_size = (
161
+ self.config.sample_size[0]
162
+ if isinstance(self.config.sample_size, (list, tuple))
163
+ else self.config.sample_size
164
+ )
165
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
166
+ self.tile_overlap_factor = 0.25
167
+
168
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
169
+ def enable_tiling(self, use_tiling: bool = True):
170
+ r"""
171
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
172
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
173
+ processing larger images.
174
+ """
175
+ self.use_tiling = use_tiling
176
+
177
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
178
+ def disable_tiling(self):
179
+ r"""
180
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
181
+ decoding in one step.
182
+ """
183
+ self.enable_tiling(False)
184
+
185
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
186
+ def enable_slicing(self):
187
+ r"""
188
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
189
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
190
+ """
191
+ self.use_slicing = True
192
+
193
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
194
+ def disable_slicing(self):
195
+ r"""
196
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
197
+ decoding in one step.
198
+ """
199
+ self.use_slicing = False
200
+
201
+ @property
202
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
203
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
204
+ r"""
205
+ Returns:
206
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
207
+ indexed by their weight names.
208
+ """
209
+ # set recursively
210
+ processors = {}
211
+
212
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
213
+ if hasattr(module, "get_processor"):
214
+ processors[f"{name}.processor"] = module.get_processor()
215
+
216
+ for sub_name, child in module.named_children():
217
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
218
+
219
+ return processors
220
+
221
+ for name, module in self.named_children():
222
+ fn_recursive_add_processors(name, module, processors)
223
+
224
+ return processors
225
+
226
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
227
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
228
+ r"""
229
+ Sets the attention processor to use to compute attention.
230
+
231
+ Parameters:
232
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
233
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
234
+ for **all** `Attention` layers.
235
+
236
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
237
+ processor. This is strongly recommended when setting trainable attention processors.
238
+
239
+ """
240
+ count = len(self.attn_processors.keys())
241
+
242
+ if isinstance(processor, dict) and len(processor) != count:
243
+ raise ValueError(
244
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
245
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
246
+ )
247
+
248
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
249
+ if hasattr(module, "set_processor"):
250
+ if not isinstance(processor, dict):
251
+ module.set_processor(processor)
252
+ else:
253
+ module.set_processor(processor.pop(f"{name}.processor"))
254
+
255
+ for sub_name, child in module.named_children():
256
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
257
+
258
+ for name, module in self.named_children():
259
+ fn_recursive_attn_processor(name, module, processor)
260
+
261
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
262
+ def set_default_attn_processor(self):
263
+ """
264
+ Disables custom attention processors and sets the default attention implementation.
265
+ """
266
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
267
+ processor = AttnAddedKVProcessor()
268
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
269
+ processor = AttnProcessor()
270
+ else:
271
+ raise ValueError(
272
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
273
+ )
274
+
275
+ self.set_attn_processor(processor)
276
+
277
+ @apply_forward_hook
278
+ def encode(
279
+ self, x: torch.Tensor, return_dict: bool = True
280
+ ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
281
+ """
282
+ Encode a batch of images into latents.
283
+
284
+ Args:
285
+ x (`torch.Tensor`): Input batch of images.
286
+ return_dict (`bool`, *optional*, defaults to `True`):
287
+ Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
288
+ instead of a plain tuple.
289
+
290
+ Returns:
291
+ The latent representations of the encoded images. If `return_dict` is True, a
292
+ [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
293
+ plain `tuple` is returned.
294
+ """
295
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
296
+ return self.tiled_encode(x, return_dict=return_dict)
297
+
298
+ if self.use_slicing and x.shape[0] > 1:
299
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
300
+ h = torch.cat(encoded_slices)
301
+ else:
302
+ h = self.encoder(x)
303
+
304
+ moments = self.quant_conv(h)
305
+ posterior = DiagonalGaussianDistribution(moments)
306
+
307
+ if not return_dict:
308
+ return (posterior,)
309
+
310
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
311
+
312
+ @apply_forward_hook
313
+ def decode(
314
+ self,
315
+ z: torch.Tensor,
316
+ generator: Optional[torch.Generator] = None,
317
+ return_dict: bool = True,
318
+ num_inference_steps: int = 2,
319
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
320
+ """
321
+ Decodes the input latent vector `z` using the consistency decoder VAE model.
322
+
323
+ Args:
324
+ z (torch.Tensor): The input latent vector.
325
+ generator (Optional[torch.Generator]): The random number generator. Default is None.
326
+ return_dict (bool): Whether to return the output as a dictionary. Default is True.
327
+ num_inference_steps (int): The number of inference steps. Default is 2.
328
+
329
+ Returns:
330
+ Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
331
+
332
+ """
333
+ z = (z * self.config.scaling_factor - self.means) / self.stds
334
+
335
+ scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
336
+ z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
337
+
338
+ batch_size, _, height, width = z.shape
339
+
340
+ self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
341
+
342
+ x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
343
+ (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
344
+ )
345
+
346
+ for t in self.decoder_scheduler.timesteps:
347
+ model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
348
+ model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
349
+ prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
350
+ x_t = prev_sample
351
+
352
+ x_0 = x_t
353
+
354
+ if not return_dict:
355
+ return (x_0,)
356
+
357
+ return DecoderOutput(sample=x_0)
358
+
359
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
360
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
361
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
362
+ for y in range(blend_extent):
363
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
364
+ return b
365
+
366
+ # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
367
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
368
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
369
+ for x in range(blend_extent):
370
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
371
+ return b
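`blend_v` and `blend_h` above are linear cross-fades over the overlap region: for an overlap of `blend_extent = E` rows, row `y` of the lower tile becomes $b_y \leftarrow (1 - y/E)\, a_{H-E+y} + (y/E)\, b_y$ for $y = 0, \dots, E-1$, and `blend_h` applies the same ramp along the width.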
372
+
373
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
374
+ r"""Encode a batch of images using a tiled encoder.
375
+
376
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
377
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
378
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
379
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
380
+ output, but they should be much less noticeable.
381
+
382
+ Args:
383
+ x (`torch.Tensor`): Input batch of images.
384
+ return_dict (`bool`, *optional*, defaults to `True`):
385
+ Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
386
+ instead of a plain tuple.
387
+
388
+ Returns:
389
+ [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
390
+ If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
391
+ is returned, otherwise a plain `tuple` is returned.
392
+ """
393
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
394
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
395
+ row_limit = self.tile_latent_min_size - blend_extent
396
+
397
+ # Split the image into 512x512 tiles and encode them separately.
398
+ rows = []
399
+ for i in range(0, x.shape[2], overlap_size):
400
+ row = []
401
+ for j in range(0, x.shape[3], overlap_size):
402
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
403
+ tile = self.encoder(tile)
404
+ tile = self.quant_conv(tile)
405
+ row.append(tile)
406
+ rows.append(row)
407
+ result_rows = []
408
+ for i, row in enumerate(rows):
409
+ result_row = []
410
+ for j, tile in enumerate(row):
411
+ # blend the above tile and the left tile
412
+ # to the current tile and add the current tile to the result row
413
+ if i > 0:
414
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
415
+ if j > 0:
416
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
417
+ result_row.append(tile[:, :, :row_limit, :row_limit])
418
+ result_rows.append(torch.cat(result_row, dim=3))
419
+
420
+ moments = torch.cat(result_rows, dim=2)
421
+ posterior = DiagonalGaussianDistribution(moments)
422
+
423
+ if not return_dict:
424
+ return (posterior,)
425
+
426
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
427
+
428
+ def forward(
429
+ self,
430
+ sample: torch.Tensor,
431
+ sample_posterior: bool = False,
432
+ return_dict: bool = True,
433
+ generator: Optional[torch.Generator] = None,
434
+ ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
435
+ r"""
436
+ Args:
437
+ sample (`torch.Tensor`): Input sample.
438
+ sample_posterior (`bool`, *optional*, defaults to `False`):
439
+ Whether to sample from the posterior.
440
+ return_dict (`bool`, *optional*, defaults to `True`):
441
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
442
+ generator (`torch.Generator`, *optional*, defaults to `None`):
443
+ Generator to use for sampling.
444
+
445
+ Returns:
446
+ [`DecoderOutput`] or `tuple`:
447
+ If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
448
+ """
449
+ x = sample
450
+ posterior = self.encode(x).latent_dist
451
+ if sample_posterior:
452
+ z = posterior.sample(generator=generator)
453
+ else:
454
+ z = posterior.mode()
455
+ dec = self.decode(z, generator=generator).sample
456
+
457
+ if not return_dict:
458
+ return (dec,)
459
+
460
+ return DecoderOutput(sample=dec)
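Beyond the pipeline example in the class docstring, the decoder can also be applied directly to Stable Diffusion latents; a minimal sketch (the latent shape and dtype are assumptions, the checkpoint id comes from the docstring above):

```py
# Hedged sketch: latent shape/dtype are assumptions; the checkpoint id is taken from the docstring above.
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda")

latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")  # SD-style 4-channel latents
with torch.no_grad():
    image = vae.decode(latents, num_inference_steps=2).sample  # (1, 3, 512, 512) after 2 consistency steps
```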
icedit/diffusers/models/autoencoders/vae.py ADDED
@@ -0,0 +1,995 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...utils import BaseOutput, is_torch_version
22
+ from ...utils.torch_utils import randn_tensor
23
+ from ..activations import get_activation
24
+ from ..attention_processor import SpatialNorm
25
+ from ..unets.unet_2d_blocks import (
26
+ AutoencoderTinyBlock,
27
+ UNetMidBlock2D,
28
+ get_down_block,
29
+ get_up_block,
30
+ )
31
+
32
+
33
+ @dataclass
34
+ class EncoderOutput(BaseOutput):
35
+ r"""
36
+ Output of encoding method.
37
+
38
+ Args:
39
+ latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
40
+ The encoded latent.
41
+ """
42
+
43
+ latent: torch.Tensor
44
+
45
+
46
+ @dataclass
47
+ class DecoderOutput(BaseOutput):
48
+ r"""
49
+ Output of decoding method.
50
+
51
+ Args:
52
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
53
+ The decoded output sample from the last layer of the model.
54
+ """
55
+
56
+ sample: torch.Tensor
57
+ commit_loss: Optional[torch.FloatTensor] = None
58
+
59
+
60
+ class Encoder(nn.Module):
61
+ r"""
62
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
63
+
64
+ Args:
65
+ in_channels (`int`, *optional*, defaults to 3):
66
+ The number of input channels.
67
+ out_channels (`int`, *optional*, defaults to 3):
68
+ The number of output channels.
69
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
70
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
71
+ options.
72
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
73
+ The number of output channels for each block.
74
+ layers_per_block (`int`, *optional*, defaults to 2):
75
+ The number of layers per block.
76
+ norm_num_groups (`int`, *optional*, defaults to 32):
77
+ The number of groups for normalization.
78
+ act_fn (`str`, *optional*, defaults to `"silu"`):
79
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
80
+ double_z (`bool`, *optional*, defaults to `True`):
81
+ Whether to double the number of output channels for the last block.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ in_channels: int = 3,
87
+ out_channels: int = 3,
88
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
89
+ block_out_channels: Tuple[int, ...] = (64,),
90
+ layers_per_block: int = 2,
91
+ norm_num_groups: int = 32,
92
+ act_fn: str = "silu",
93
+ double_z: bool = True,
94
+ mid_block_add_attention=True,
95
+ ):
96
+ super().__init__()
97
+ self.layers_per_block = layers_per_block
98
+
99
+ self.conv_in = nn.Conv2d(
100
+ in_channels,
101
+ block_out_channels[0],
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1,
105
+ )
106
+
107
+ self.down_blocks = nn.ModuleList([])
108
+
109
+ # down
110
+ output_channel = block_out_channels[0]
111
+ for i, down_block_type in enumerate(down_block_types):
112
+ input_channel = output_channel
113
+ output_channel = block_out_channels[i]
114
+ is_final_block = i == len(block_out_channels) - 1
115
+
116
+ down_block = get_down_block(
117
+ down_block_type,
118
+ num_layers=self.layers_per_block,
119
+ in_channels=input_channel,
120
+ out_channels=output_channel,
121
+ add_downsample=not is_final_block,
122
+ resnet_eps=1e-6,
123
+ downsample_padding=0,
124
+ resnet_act_fn=act_fn,
125
+ resnet_groups=norm_num_groups,
126
+ attention_head_dim=output_channel,
127
+ temb_channels=None,
128
+ )
129
+ self.down_blocks.append(down_block)
130
+
131
+ # mid
132
+ self.mid_block = UNetMidBlock2D(
133
+ in_channels=block_out_channels[-1],
134
+ resnet_eps=1e-6,
135
+ resnet_act_fn=act_fn,
136
+ output_scale_factor=1,
137
+ resnet_time_scale_shift="default",
138
+ attention_head_dim=block_out_channels[-1],
139
+ resnet_groups=norm_num_groups,
140
+ temb_channels=None,
141
+ add_attention=mid_block_add_attention,
142
+ )
143
+
144
+ # out
145
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
146
+ self.conv_act = nn.SiLU()
147
+
148
+ conv_out_channels = 2 * out_channels if double_z else out_channels
149
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
150
+
151
+ self.gradient_checkpointing = False
152
+
153
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
154
+ r"""The forward method of the `Encoder` class."""
155
+
156
+ sample = self.conv_in(sample)
157
+
158
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
159
+
160
+ def create_custom_forward(module):
161
+ def custom_forward(*inputs):
162
+ return module(*inputs)
163
+
164
+ return custom_forward
165
+
166
+ # down
167
+ if is_torch_version(">=", "1.11.0"):
168
+ for down_block in self.down_blocks:
169
+ sample = torch.utils.checkpoint.checkpoint(
170
+ create_custom_forward(down_block), sample, use_reentrant=False
171
+ )
172
+ # middle
173
+ sample = torch.utils.checkpoint.checkpoint(
174
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
175
+ )
176
+ else:
177
+ for down_block in self.down_blocks:
178
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
179
+ # middle
180
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
181
+
182
+ else:
183
+ # down
184
+ for down_block in self.down_blocks:
185
+ sample = down_block(sample)
186
+
187
+ # middle
188
+ sample = self.mid_block(sample)
189
+
190
+ # post-process
191
+ sample = self.conv_norm_out(sample)
192
+ sample = self.conv_act(sample)
193
+ sample = self.conv_out(sample)
194
+
195
+ return sample
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ r"""
200
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
201
+
202
+ Args:
203
+ in_channels (`int`, *optional*, defaults to 3):
204
+ The number of input channels.
205
+ out_channels (`int`, *optional*, defaults to 3):
206
+ The number of output channels.
207
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
208
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
209
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
210
+ The number of output channels for each block.
211
+ layers_per_block (`int`, *optional*, defaults to 2):
212
+ The number of layers per block.
213
+ norm_num_groups (`int`, *optional*, defaults to 32):
214
+ The number of groups for normalization.
215
+ act_fn (`str`, *optional*, defaults to `"silu"`):
216
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
217
+ norm_type (`str`, *optional*, defaults to `"group"`):
218
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
219
+ """
220
+
221
+ def __init__(
222
+ self,
223
+ in_channels: int = 3,
224
+ out_channels: int = 3,
225
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
226
+ block_out_channels: Tuple[int, ...] = (64,),
227
+ layers_per_block: int = 2,
228
+ norm_num_groups: int = 32,
229
+ act_fn: str = "silu",
230
+ norm_type: str = "group", # group, spatial
231
+ mid_block_add_attention=True,
232
+ ):
233
+ super().__init__()
234
+ self.layers_per_block = layers_per_block
235
+
236
+ self.conv_in = nn.Conv2d(
237
+ in_channels,
238
+ block_out_channels[-1],
239
+ kernel_size=3,
240
+ stride=1,
241
+ padding=1,
242
+ )
243
+
244
+ self.up_blocks = nn.ModuleList([])
245
+
246
+ temb_channels = in_channels if norm_type == "spatial" else None
247
+
248
+ # mid
249
+ self.mid_block = UNetMidBlock2D(
250
+ in_channels=block_out_channels[-1],
251
+ resnet_eps=1e-6,
252
+ resnet_act_fn=act_fn,
253
+ output_scale_factor=1,
254
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
255
+ attention_head_dim=block_out_channels[-1],
256
+ resnet_groups=norm_num_groups,
257
+ temb_channels=temb_channels,
258
+ add_attention=mid_block_add_attention,
259
+ )
260
+
261
+ # up
262
+ reversed_block_out_channels = list(reversed(block_out_channels))
263
+ output_channel = reversed_block_out_channels[0]
264
+ for i, up_block_type in enumerate(up_block_types):
265
+ prev_output_channel = output_channel
266
+ output_channel = reversed_block_out_channels[i]
267
+
268
+ is_final_block = i == len(block_out_channels) - 1
269
+
270
+ up_block = get_up_block(
271
+ up_block_type,
272
+ num_layers=self.layers_per_block + 1,
273
+ in_channels=prev_output_channel,
274
+ out_channels=output_channel,
275
+ prev_output_channel=None,
276
+ add_upsample=not is_final_block,
277
+ resnet_eps=1e-6,
278
+ resnet_act_fn=act_fn,
279
+ resnet_groups=norm_num_groups,
280
+ attention_head_dim=output_channel,
281
+ temb_channels=temb_channels,
282
+ resnet_time_scale_shift=norm_type,
283
+ )
284
+ self.up_blocks.append(up_block)
285
+ prev_output_channel = output_channel
286
+
287
+ # out
288
+ if norm_type == "spatial":
289
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
290
+ else:
291
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
292
+ self.conv_act = nn.SiLU()
293
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
294
+
295
+ self.gradient_checkpointing = False
296
+
297
+ def forward(
298
+ self,
299
+ sample: torch.Tensor,
300
+ latent_embeds: Optional[torch.Tensor] = None,
301
+ ) -> torch.Tensor:
302
+ r"""The forward method of the `Decoder` class."""
303
+
304
+ sample = self.conv_in(sample)
305
+
306
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
307
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
308
+
309
+ def create_custom_forward(module):
310
+ def custom_forward(*inputs):
311
+ return module(*inputs)
312
+
313
+ return custom_forward
314
+
315
+ if is_torch_version(">=", "1.11.0"):
316
+ # middle
317
+ sample = torch.utils.checkpoint.checkpoint(
318
+ create_custom_forward(self.mid_block),
319
+ sample,
320
+ latent_embeds,
321
+ use_reentrant=False,
322
+ )
323
+ sample = sample.to(upscale_dtype)
324
+
325
+ # up
326
+ for up_block in self.up_blocks:
327
+ sample = torch.utils.checkpoint.checkpoint(
328
+ create_custom_forward(up_block),
329
+ sample,
330
+ latent_embeds,
331
+ use_reentrant=False,
332
+ )
333
+ else:
334
+ # middle
335
+ sample = torch.utils.checkpoint.checkpoint(
336
+ create_custom_forward(self.mid_block), sample, latent_embeds
337
+ )
338
+ sample = sample.to(upscale_dtype)
339
+
340
+ # up
341
+ for up_block in self.up_blocks:
342
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
343
+ else:
344
+ # middle
345
+ sample = self.mid_block(sample, latent_embeds)
346
+ sample = sample.to(upscale_dtype)
347
+
348
+ # up
349
+ for up_block in self.up_blocks:
350
+ sample = up_block(sample, latent_embeds)
351
+
352
+ # post-process
353
+ if latent_embeds is None:
354
+ sample = self.conv_norm_out(sample)
355
+ else:
356
+ sample = self.conv_norm_out(sample, latent_embeds)
357
+ sample = self.conv_act(sample)
358
+ sample = self.conv_out(sample)
359
+
360
+ return sample
361
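# Illustrative usage sketch for the `Decoder` above: the channel counts and the
# latent shape are assumptions chosen to mirror a typical SD-style VAE config,
# not values taken from this file.
def _example_decoder_usage():
    import torch

    decoder = Decoder(
        in_channels=4,                              # latent channels (assumed)
        out_channels=3,                             # RGB output
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
    )
    latent = torch.randn(1, 4, 32, 32)              # (batch, channels, height, width)
    with torch.no_grad():
        image = decoder(latent)                     # every up block except the last doubles H and W
    return image.shape                              # torch.Size([1, 3, 256, 256])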
+
362
+
363
+ class UpSample(nn.Module):
364
+ r"""
365
+ The `UpSample` layer of a variational autoencoder that upsamples its input.
366
+
367
+ Args:
368
+ in_channels (`int`, *optional*, defaults to 3):
369
+ The number of input channels.
370
+ out_channels (`int`, *optional*, defaults to 3):
371
+ The number of output channels.
372
+ """
373
+
374
+ def __init__(
375
+ self,
376
+ in_channels: int,
377
+ out_channels: int,
378
+ ) -> None:
379
+ super().__init__()
380
+ self.in_channels = in_channels
381
+ self.out_channels = out_channels
382
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
383
+
384
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
385
+ r"""The forward method of the `UpSample` class."""
386
+ x = torch.relu(x)
387
+ x = self.deconv(x)
388
+ return x
389
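# Illustrative usage sketch for `UpSample`: a ReLU followed by a stride-2
# transposed convolution, so the spatial size doubles. The shape below is an
# assumption for the example.
def _example_upsample_usage():
    import torch

    up = UpSample(in_channels=64, out_channels=32)
    x = torch.randn(1, 64, 16, 16)
    return up(x).shape                              # torch.Size([1, 32, 32, 32])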
+
390
+
391
+ class MaskConditionEncoder(nn.Module):
392
+ """
393
+ Encoder used in [`AsymmetricAutoencoderKL`] (via [`MaskConditionDecoder`]) to embed the masked image into multi-resolution features that condition the decoder.
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ in_ch: int,
399
+ out_ch: int = 192,
400
+ res_ch: int = 768,
401
+ stride: int = 16,
402
+ ) -> None:
403
+ super().__init__()
404
+
405
+ channels = []
406
+ while stride > 1:
407
+ stride = stride // 2
408
+ in_ch_ = out_ch * 2
409
+ if out_ch > res_ch:
410
+ out_ch = res_ch
411
+ if stride == 1:
412
+ in_ch_ = res_ch
413
+ channels.append((in_ch_, out_ch))
414
+ out_ch *= 2
415
+
416
+ out_channels = []
417
+ for _in_ch, _out_ch in channels:
418
+ out_channels.append(_out_ch)
419
+ out_channels.append(channels[-1][0])
420
+
421
+ layers = []
422
+ in_ch_ = in_ch
423
+ for l in range(len(out_channels)):
424
+ out_ch_ = out_channels[l]
425
+ if l == 0 or l == 1:
426
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
427
+ else:
428
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
429
+ in_ch_ = out_ch_
430
+
431
+ self.layers = nn.Sequential(*layers)
432
+
433
+ def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
434
+ r"""The forward method of the `MaskConditionEncoder` class."""
435
+ out = {}
436
+ for l in range(len(self.layers)):
437
+ layer = self.layers[l]
438
+ x = layer(x)
439
+ out[str(tuple(x.shape))] = x
440
+ x = torch.relu(x)
441
+ return out
442
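# Illustrative usage sketch for `MaskConditionEncoder`: it returns a dict of
# intermediate feature maps keyed by the string form of their shape, which
# `MaskConditionDecoder` looks up to blend masked-image features at matching
# resolutions. The input size below is an assumption for the example.
def _example_mask_condition_encoder_usage():
    import torch

    enc = MaskConditionEncoder(in_ch=3, out_ch=192, res_ch=768, stride=16)
    masked_image = torch.randn(1, 3, 256, 256)
    feats = enc(masked_image)
    # keys run from "(1, 192, 256, 256)" down to "(1, 768, 32, 32)", one per conv layer
    return list(feats.keys())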
+
443
+
444
+ class MaskConditionDecoder(nn.Module):
445
+ r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
446
+ decoder with a conditioner on the mask and masked image.
447
+
448
+ Args:
449
+ in_channels (`int`, *optional*, defaults to 3):
450
+ The number of input channels.
451
+ out_channels (`int`, *optional*, defaults to 3):
452
+ The number of output channels.
453
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
454
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
455
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
456
+ The number of output channels for each block.
457
+ layers_per_block (`int`, *optional*, defaults to 2):
458
+ The number of layers per block.
459
+ norm_num_groups (`int`, *optional*, defaults to 32):
460
+ The number of groups for normalization.
461
+ act_fn (`str`, *optional*, defaults to `"silu"`):
462
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
463
+ norm_type (`str`, *optional*, defaults to `"group"`):
464
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
465
+ """
466
+
467
+ def __init__(
468
+ self,
469
+ in_channels: int = 3,
470
+ out_channels: int = 3,
471
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
472
+ block_out_channels: Tuple[int, ...] = (64,),
473
+ layers_per_block: int = 2,
474
+ norm_num_groups: int = 32,
475
+ act_fn: str = "silu",
476
+ norm_type: str = "group", # group, spatial
477
+ ):
478
+ super().__init__()
479
+ self.layers_per_block = layers_per_block
480
+
481
+ self.conv_in = nn.Conv2d(
482
+ in_channels,
483
+ block_out_channels[-1],
484
+ kernel_size=3,
485
+ stride=1,
486
+ padding=1,
487
+ )
488
+
489
+ self.up_blocks = nn.ModuleList([])
490
+
491
+ temb_channels = in_channels if norm_type == "spatial" else None
492
+
493
+ # mid
494
+ self.mid_block = UNetMidBlock2D(
495
+ in_channels=block_out_channels[-1],
496
+ resnet_eps=1e-6,
497
+ resnet_act_fn=act_fn,
498
+ output_scale_factor=1,
499
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
500
+ attention_head_dim=block_out_channels[-1],
501
+ resnet_groups=norm_num_groups,
502
+ temb_channels=temb_channels,
503
+ )
504
+
505
+ # up
506
+ reversed_block_out_channels = list(reversed(block_out_channels))
507
+ output_channel = reversed_block_out_channels[0]
508
+ for i, up_block_type in enumerate(up_block_types):
509
+ prev_output_channel = output_channel
510
+ output_channel = reversed_block_out_channels[i]
511
+
512
+ is_final_block = i == len(block_out_channels) - 1
513
+
514
+ up_block = get_up_block(
515
+ up_block_type,
516
+ num_layers=self.layers_per_block + 1,
517
+ in_channels=prev_output_channel,
518
+ out_channels=output_channel,
519
+ prev_output_channel=None,
520
+ add_upsample=not is_final_block,
521
+ resnet_eps=1e-6,
522
+ resnet_act_fn=act_fn,
523
+ resnet_groups=norm_num_groups,
524
+ attention_head_dim=output_channel,
525
+ temb_channels=temb_channels,
526
+ resnet_time_scale_shift=norm_type,
527
+ )
528
+ self.up_blocks.append(up_block)
529
+ prev_output_channel = output_channel
530
+
531
+ # condition encoder
532
+ self.condition_encoder = MaskConditionEncoder(
533
+ in_ch=out_channels,
534
+ out_ch=block_out_channels[0],
535
+ res_ch=block_out_channels[-1],
536
+ )
537
+
538
+ # out
539
+ if norm_type == "spatial":
540
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
541
+ else:
542
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
543
+ self.conv_act = nn.SiLU()
544
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
545
+
546
+ self.gradient_checkpointing = False
547
+
548
+ def forward(
549
+ self,
550
+ z: torch.Tensor,
551
+ image: Optional[torch.Tensor] = None,
552
+ mask: Optional[torch.Tensor] = None,
553
+ latent_embeds: Optional[torch.Tensor] = None,
554
+ ) -> torch.Tensor:
555
+ r"""The forward method of the `MaskConditionDecoder` class."""
556
+ sample = z
557
+ sample = self.conv_in(sample)
558
+
559
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
560
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
561
+
562
+ def create_custom_forward(module):
563
+ def custom_forward(*inputs):
564
+ return module(*inputs)
565
+
566
+ return custom_forward
567
+
568
+ if is_torch_version(">=", "1.11.0"):
569
+ # middle
570
+ sample = torch.utils.checkpoint.checkpoint(
571
+ create_custom_forward(self.mid_block),
572
+ sample,
573
+ latent_embeds,
574
+ use_reentrant=False,
575
+ )
576
+ sample = sample.to(upscale_dtype)
577
+
578
+ # condition encoder
579
+ if image is not None and mask is not None:
580
+ masked_image = (1 - mask) * image
581
+ im_x = torch.utils.checkpoint.checkpoint(
582
+ create_custom_forward(self.condition_encoder),
583
+ masked_image,
584
+ mask,
585
+ use_reentrant=False,
586
+ )
587
+
588
+ # up
589
+ for up_block in self.up_blocks:
590
+ if image is not None and mask is not None:
591
+ sample_ = im_x[str(tuple(sample.shape))]
592
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
593
+ sample = sample * mask_ + sample_ * (1 - mask_)
594
+ sample = torch.utils.checkpoint.checkpoint(
595
+ create_custom_forward(up_block),
596
+ sample,
597
+ latent_embeds,
598
+ use_reentrant=False,
599
+ )
600
+ if image is not None and mask is not None:
601
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
602
+ else:
603
+ # middle
604
+ sample = torch.utils.checkpoint.checkpoint(
605
+ create_custom_forward(self.mid_block), sample, latent_embeds
606
+ )
607
+ sample = sample.to(upscale_dtype)
608
+
609
+ # condition encoder
610
+ if image is not None and mask is not None:
611
+ masked_image = (1 - mask) * image
612
+ im_x = torch.utils.checkpoint.checkpoint(
613
+ create_custom_forward(self.condition_encoder),
614
+ masked_image,
615
+ mask,
616
+ )
617
+
618
+ # up
619
+ for up_block in self.up_blocks:
620
+ if image is not None and mask is not None:
621
+ sample_ = im_x[str(tuple(sample.shape))]
622
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
623
+ sample = sample * mask_ + sample_ * (1 - mask_)
624
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
625
+ if image is not None and mask is not None:
626
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
627
+ else:
628
+ # middle
629
+ sample = self.mid_block(sample, latent_embeds)
630
+ sample = sample.to(upscale_dtype)
631
+
632
+ # condition encoder
633
+ if image is not None and mask is not None:
634
+ masked_image = (1 - mask) * image
635
+ im_x = self.condition_encoder(masked_image, mask)
636
+
637
+ # up
638
+ for up_block in self.up_blocks:
639
+ if image is not None and mask is not None:
640
+ sample_ = im_x[str(tuple(sample.shape))]
641
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
642
+ sample = sample * mask_ + sample_ * (1 - mask_)
643
+ sample = up_block(sample, latent_embeds)
644
+ if image is not None and mask is not None:
645
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
646
+
647
+ # post-process
648
+ if latent_embeds is None:
649
+ sample = self.conv_norm_out(sample)
650
+ else:
651
+ sample = self.conv_norm_out(sample, latent_embeds)
652
+ sample = self.conv_act(sample)
653
+ sample = self.conv_out(sample)
654
+
655
+ return sample
656
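# Illustrative usage sketch for `MaskConditionDecoder`: decoding a latent while
# re-blending features of the masked image at every resolution (the inpainting
# path of the asymmetric VAE). Latent shape, channel counts and the mask
# convention (1 = region to generate, 0 = region to keep) are assumptions.
def _example_mask_condition_decoder_usage():
    import torch

    decoder = MaskConditionDecoder(
        in_channels=4,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
    )
    z = torch.randn(1, 4, 32, 32)
    image = torch.randn(1, 3, 256, 256)             # original (unmasked) image
    mask = torch.zeros(1, 1, 256, 256)
    mask[:, :, 64:192, 64:192] = 1.0                # square hole to fill
    with torch.no_grad():
        out = decoder(z, image=image, mask=mask)
    return out.shape                                 # torch.Size([1, 3, 256, 256])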
+
657
+
658
+ class VectorQuantizer(nn.Module):
659
+ """
660
+ Improved version of the original `VectorQuantizer` that can be used as a drop-in replacement. It mostly avoids costly matrix
661
+ multiplications and allows for post-hoc remapping of indices.
662
+ """
663
+
664
+ # NOTE: due to a bug, the beta term was applied to the wrong term. For
665
+ # backwards compatibility we use the buggy version by default, but you can
666
+ # specify legacy=False to fix it.
667
+ def __init__(
668
+ self,
669
+ n_e: int,
670
+ vq_embed_dim: int,
671
+ beta: float,
672
+ remap=None,
673
+ unknown_index: str = "random",
674
+ sane_index_shape: bool = False,
675
+ legacy: bool = True,
676
+ ):
677
+ super().__init__()
678
+ self.n_e = n_e
679
+ self.vq_embed_dim = vq_embed_dim
680
+ self.beta = beta
681
+ self.legacy = legacy
682
+
683
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
684
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
685
+
686
+ self.remap = remap
687
+ if self.remap is not None:
688
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
689
+ self.used: torch.Tensor
690
+ self.re_embed = self.used.shape[0]
691
+ self.unknown_index = unknown_index # "random" or "extra" or integer
692
+ if self.unknown_index == "extra":
693
+ self.unknown_index = self.re_embed
694
+ self.re_embed = self.re_embed + 1
695
+ print(
696
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
697
+ f"Using {self.unknown_index} for unknown indices."
698
+ )
699
+ else:
700
+ self.re_embed = n_e
701
+
702
+ self.sane_index_shape = sane_index_shape
703
+
704
+ def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
705
+ ishape = inds.shape
706
+ assert len(ishape) > 1
707
+ inds = inds.reshape(ishape[0], -1)
708
+ used = self.used.to(inds)
709
+ match = (inds[:, :, None] == used[None, None, ...]).long()
710
+ new = match.argmax(-1)
711
+ unknown = match.sum(2) < 1
712
+ if self.unknown_index == "random":
713
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
714
+ else:
715
+ new[unknown] = self.unknown_index
716
+ return new.reshape(ishape)
717
+
718
+ def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
719
+ ishape = inds.shape
720
+ assert len(ishape) > 1
721
+ inds = inds.reshape(ishape[0], -1)
722
+ used = self.used.to(inds)
723
+ if self.re_embed > self.used.shape[0]: # extra token
724
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
725
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
726
+ return back.reshape(ishape)
727
+
728
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
729
+ # reshape z -> (batch, height, width, channel) and flatten
730
+ z = z.permute(0, 2, 3, 1).contiguous()
731
+ z_flattened = z.view(-1, self.vq_embed_dim)
732
+
733
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
734
+ min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
735
+
736
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
737
+ perplexity = None
738
+ min_encodings = None
739
+
740
+ # compute loss for embedding
741
+ if not self.legacy:
742
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
743
+ else:
744
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
745
+
746
+ # preserve gradients
747
+ z_q: torch.Tensor = z + (z_q - z).detach()
748
+
749
+ # reshape back to match original input shape
750
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
751
+
752
+ if self.remap is not None:
753
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
754
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
755
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
756
+
757
+ if self.sane_index_shape:
758
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
759
+
760
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
761
+
762
+ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
763
+ # shape specifying (batch, height, width, channel)
764
+ if self.remap is not None:
765
+ indices = indices.reshape(shape[0], -1) # add batch axis
766
+ indices = self.unmap_to_all(indices)
767
+ indices = indices.reshape(-1) # flatten again
768
+
769
+ # get quantized latent vectors
770
+ z_q: torch.Tensor = self.embedding(indices)
771
+
772
+ if shape is not None:
773
+ z_q = z_q.view(shape)
774
+ # reshape back to match original input shape
775
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
776
+
777
+ return z_q
778
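# Illustrative usage sketch for `VectorQuantizer`: quantize a feature map
# against a learned codebook and read back the straight-through output, the
# commitment loss and the selected indices. Codebook size, embedding dim and
# beta are assumptions for the example.
def _example_vector_quantizer_usage():
    import torch

    quantizer = VectorQuantizer(n_e=256, vq_embed_dim=8, beta=0.25)
    z = torch.randn(2, 8, 16, 16)                   # (batch, vq_embed_dim, height, width)
    z_q, loss, (_, _, indices) = quantizer(z)
    # z_q has the same shape as z and carries gradients straight through to z;
    # indices holds the chosen codebook ids, shape (2 * 16 * 16,) here.
    return z_q.shape, loss, indices.shape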
+
779
+
780
+ class DiagonalGaussianDistribution(object):
781
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
782
+ self.parameters = parameters
783
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
784
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
785
+ self.deterministic = deterministic
786
+ self.std = torch.exp(0.5 * self.logvar)
787
+ self.var = torch.exp(self.logvar)
788
+ if self.deterministic:
789
+ self.var = self.std = torch.zeros_like(
790
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
791
+ )
792
+
793
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
794
+ # make sure sample is on the same device as the parameters and has same dtype
795
+ sample = randn_tensor(
796
+ self.mean.shape,
797
+ generator=generator,
798
+ device=self.parameters.device,
799
+ dtype=self.parameters.dtype,
800
+ )
801
+ x = self.mean + self.std * sample
802
+ return x
803
+
804
+ def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
805
+ if self.deterministic:
806
+ return torch.Tensor([0.0])
807
+ else:
808
+ if other is None:
809
+ return 0.5 * torch.sum(
810
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
811
+ dim=[1, 2, 3],
812
+ )
813
+ else:
814
+ return 0.5 * torch.sum(
815
+ torch.pow(self.mean - other.mean, 2) / other.var
816
+ + self.var / other.var
817
+ - 1.0
818
+ - self.logvar
819
+ + other.logvar,
820
+ dim=[1, 2, 3],
821
+ )
822
+
823
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
824
+ if self.deterministic:
825
+ return torch.Tensor([0.0])
826
+ logtwopi = np.log(2.0 * np.pi)
827
+ return 0.5 * torch.sum(
828
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
829
+ dim=dims,
830
+ )
831
+
832
+ def mode(self) -> torch.Tensor:
833
+ return self.mean
834
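# Illustrative usage sketch for `DiagonalGaussianDistribution`: the encoder
# emits mean and logvar stacked along the channel axis; this wraps them, draws
# a reparameterized sample and evaluates the KL term against a standard
# normal. Shapes are assumptions for the example.
def _example_diagonal_gaussian_usage():
    import torch

    moments = torch.randn(1, 8, 32, 32)             # 4 mean channels + 4 logvar channels
    posterior = DiagonalGaussianDistribution(moments)
    latent = posterior.sample()                     # (1, 4, 32, 32): mean + std * noise
    kl = posterior.kl()                             # per-sample KL, shape (1,)
    return latent.shape, kl.shape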
+
835
+
836
+ class EncoderTiny(nn.Module):
837
+ r"""
838
+ The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
839
+
840
+ Args:
841
+ in_channels (`int`):
842
+ The number of input channels.
843
+ out_channels (`int`):
844
+ The number of output channels.
845
+ num_blocks (`Tuple[int, ...]`):
846
+ Each value of the tuple represents a `Conv2d` layer followed by that many `AutoencoderTinyBlock` layers to
847
+ use.
848
+ block_out_channels (`Tuple[int, ...]`):
849
+ The number of output channels for each block.
850
+ act_fn (`str`):
851
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
852
+ """
853
+
854
+ def __init__(
855
+ self,
856
+ in_channels: int,
857
+ out_channels: int,
858
+ num_blocks: Tuple[int, ...],
859
+ block_out_channels: Tuple[int, ...],
860
+ act_fn: str,
861
+ ):
862
+ super().__init__()
863
+
864
+ layers = []
865
+ for i, num_block in enumerate(num_blocks):
866
+ num_channels = block_out_channels[i]
867
+
868
+ if i == 0:
869
+ layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
870
+ else:
871
+ layers.append(
872
+ nn.Conv2d(
873
+ num_channels,
874
+ num_channels,
875
+ kernel_size=3,
876
+ padding=1,
877
+ stride=2,
878
+ bias=False,
879
+ )
880
+ )
881
+
882
+ for _ in range(num_block):
883
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
884
+
885
+ layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
886
+
887
+ self.layers = nn.Sequential(*layers)
888
+ self.gradient_checkpointing = False
889
+
890
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
891
+ r"""The forward method of the `EncoderTiny` class."""
892
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
893
+
894
+ def create_custom_forward(module):
895
+ def custom_forward(*inputs):
896
+ return module(*inputs)
897
+
898
+ return custom_forward
899
+
900
+ if is_torch_version(">=", "1.11.0"):
901
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
902
+ else:
903
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
904
+
905
+ else:
906
+ # scale image from [-1, 1] to [0, 1] to match TAESD convention
907
+ x = self.layers(x.add(1).div(2))
908
+
909
+ return x
910
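# Illustrative usage sketch for `EncoderTiny`: a TAESD-style tiny encoder. In
# the non-checkpointed path the input is expected in [-1, 1] and rescaled to
# [0, 1] internally. The block counts and channels mirror common TAESD
# settings but are assumptions for the example.
def _example_encoder_tiny_usage():
    import torch

    encoder = EncoderTiny(
        in_channels=3,
        out_channels=4,
        num_blocks=(1, 3, 3, 3),
        block_out_channels=(64, 64, 64, 64),
        act_fn="relu",
    )
    image = torch.randn(1, 3, 256, 256).clamp(-1, 1)
    with torch.no_grad():
        latent = encoder(image)                     # three stride-2 convs -> 8x downsampling
    return latent.shape                              # torch.Size([1, 4, 32, 32])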
+
911
+
912
+ class DecoderTiny(nn.Module):
913
+ r"""
914
+ The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
915
+
916
+ Args:
917
+ in_channels (`int`):
918
+ The number of input channels.
919
+ out_channels (`int`):
920
+ The number of output channels.
921
+ num_blocks (`Tuple[int, ...]`):
922
+ Each value of the tuple represents a `Conv2d` layer followed by that many `AutoencoderTinyBlock` layers to
923
+ use.
924
+ block_out_channels (`Tuple[int, ...]`):
925
+ The number of output channels for each block.
926
+ upsampling_scaling_factor (`int`):
927
+ The scaling factor to use for upsampling.
928
+ act_fn (`str`):
929
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
930
+ """
931
+
932
+ def __init__(
933
+ self,
934
+ in_channels: int,
935
+ out_channels: int,
936
+ num_blocks: Tuple[int, ...],
937
+ block_out_channels: Tuple[int, ...],
938
+ upsampling_scaling_factor: int,
939
+ act_fn: str,
940
+ upsample_fn: str,
941
+ ):
942
+ super().__init__()
943
+
944
+ layers = [
945
+ nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
946
+ get_activation(act_fn),
947
+ ]
948
+
949
+ for i, num_block in enumerate(num_blocks):
950
+ is_final_block = i == (len(num_blocks) - 1)
951
+ num_channels = block_out_channels[i]
952
+
953
+ for _ in range(num_block):
954
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
955
+
956
+ if not is_final_block:
957
+ layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
958
+
959
+ conv_out_channel = num_channels if not is_final_block else out_channels
960
+ layers.append(
961
+ nn.Conv2d(
962
+ num_channels,
963
+ conv_out_channel,
964
+ kernel_size=3,
965
+ padding=1,
966
+ bias=is_final_block,
967
+ )
968
+ )
969
+
970
+ self.layers = nn.Sequential(*layers)
971
+ self.gradient_checkpointing = False
972
+
973
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
974
+ r"""The forward method of the `DecoderTiny` class."""
975
+ # Clamp.
976
+ x = torch.tanh(x / 3) * 3
977
+
978
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
979
+
980
+ def create_custom_forward(module):
981
+ def custom_forward(*inputs):
982
+ return module(*inputs)
983
+
984
+ return custom_forward
985
+
986
+ if is_torch_version(">=", "1.11.0"):
987
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
988
+ else:
989
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
990
+
991
+ else:
992
+ x = self.layers(x)
993
+
994
+ # scale image from [0, 1] to [-1, 1] to match diffusers convention
995
+ return x.mul(2).sub(1)
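# Illustrative usage sketch for `DecoderTiny`: the matching TAESD-style tiny
# decoder. It soft-clamps the latent with a scaled tanh and maps the result
# back to an image in [-1, 1]. The block counts and channels are assumptions
# for the example.
def _example_decoder_tiny_usage():
    import torch

    decoder = DecoderTiny(
        in_channels=4,
        out_channels=3,
        num_blocks=(3, 3, 3, 1),
        block_out_channels=(64, 64, 64, 64),
        upsampling_scaling_factor=2,
        act_fn="relu",
        upsample_fn="nearest",
    )
    latent = torch.randn(1, 4, 32, 32)
    with torch.no_grad():
        image = decoder(latent)                     # three 2x upsamples -> 8x upscaling
    return image.shape                               # torch.Size([1, 3, 256, 256])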