Diffusers · Safetensors · PixCellPipeline

Commit 7147c53 (verified) · AlexGraikos committed · Parent(s): 0f2387c

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test_image.jpg filter=lfs diff=lfs merge=lfs -text
model_index.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "_class_name": "PixCellPipeline",
3
+ "_diffusers_version": "0.32.2",
4
+ "transformer": [
5
+ "pixcell_transformer_2d",
6
+ "PixCellTransformer2DModel"
7
+ ],
8
+ "vae": [
9
+ "diffusers",
10
+ "AutoencoderKL"
11
+ ],
12
+ "scheduler": [
13
+ "diffusers",
14
+ "DPMSolverMultistepScheduler"
15
+ ]
16
+ }
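
model_index.json registers three components: the custom `PixCellTransformer2DModel` (loaded from the repository's `pixcell_transformer_2d` module), a stock diffusers `AutoencoderKL`, and a `DPMSolverMultistepScheduler`, all wired into the `PixCellPipeline` class defined in `pipeline.py`. Below is a minimal sketch of assembling the pipeline by hand from a local clone; the clone path is a placeholder, it assumes the `vae/` subfolder referenced here exists in the full repository (it is not part of this diff), and it assumes the repo root and `transformer/` directory are on the Python import path.

```python
# Hedged sketch: assemble PixCellPipeline from the components listed in model_index.json.
# "path/to/pixcell-repo" is a placeholder; the vae/ subfolder is assumed to exist in the repo.
import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler

from pipeline import PixCellPipeline                           # pipeline.py at the repo root
from pixcell_transformer_2d import PixCellTransformer2DModel   # transformer/pixcell_transformer_2d.py

repo = "path/to/pixcell-repo"
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16)
transformer = PixCellTransformer2DModel.from_pretrained(repo, subfolder="transformer", torch_dtype=torch.float16)
scheduler = DPMSolverMultistepScheduler.from_pretrained(repo, subfolder="scheduler")

# PixCellPipeline.__init__ takes exactly these three modules (see pipeline.py below).
pipe = PixCellPipeline(vae=vae, transformer=transformer, scheduler=scheduler).to("cuda")
```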
pipeline.py ADDED
@@ -0,0 +1,1122 @@
1
+ # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import urllib.parse as ul
17
+ from typing import Callable, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from diffusers.image_processor import PixArtImageProcessor
23
+ from diffusers.models import AutoencoderKL
24
+ from diffusers.schedulers import DPMSolverMultistepScheduler
25
+ from diffusers.utils import (
26
+ BACKENDS_MAPPING,
27
+ deprecate,
28
+ logging,
29
+ replace_example_docstring,
30
+ )
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
33
+
34
+ from pixcell_transformer_2d import PixCellTransformer2DModel
35
+
36
+
37
+ from typing import Any, Dict, Optional, Union
38
+
39
+
40
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
41
+ from diffusers.utils import is_torch_version, logging
42
+ from diffusers.models.attention import BasicTransformerBlock
43
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
44
+ from diffusers.models.embeddings import PatchEmbed
45
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
46
+ from diffusers.models.modeling_utils import ModelMixin
47
+ from diffusers.models.normalization import AdaLayerNormSingle
48
+
49
+ from typing import List, Optional, Tuple, Union
50
+
51
+ import numpy as np
52
+ import torch
53
+ import torch.nn.functional as F
54
+ from torch import nn
55
+
56
+ from diffusers.models.activations import deprecate, FP32SiLU
57
+
58
+
59
+ def pixcell_get_2d_sincos_pos_embed(
60
+ embed_dim,
61
+ grid_size,
62
+ cls_token=False,
63
+ extra_tokens=0,
64
+ interpolation_scale=1.0,
65
+ base_size=16,
66
+ device: Optional[torch.device] = None,
67
+ phase=0,
68
+ output_type: str = "np",
69
+ ):
70
+ """
71
+ Creates 2D sinusoidal positional embeddings.
72
+
73
+ Args:
74
+ embed_dim (`int`):
75
+ The embedding dimension.
76
+ grid_size (`int`):
77
+ The size of the grid height and width.
78
+ cls_token (`bool`, defaults to `False`):
79
+ Whether or not to add a classification token.
80
+ extra_tokens (`int`, defaults to `0`):
81
+ The number of extra tokens to add.
82
+ interpolation_scale (`float`, defaults to `1.0`):
83
+ The scale of the interpolation.
84
+
85
+ Returns:
86
+ pos_embed (`torch.Tensor`):
87
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
88
+ embed_dim]` if using cls_token
89
+ """
90
+ if output_type == "np":
91
+ deprecation_message = (
92
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
93
+ " `from_numpy` is no longer required."
94
+ " Pass `output_type='pt' to use the new version now."
95
+ )
96
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
97
+ raise ValueError("Not supported")
98
+ if isinstance(grid_size, int):
99
+ grid_size = (grid_size, grid_size)
100
+
101
+ grid_h = (
102
+ torch.arange(grid_size[0], device=device, dtype=torch.float32)
103
+ / (grid_size[0] / base_size)
104
+ / interpolation_scale
105
+ )
106
+ grid_w = (
107
+ torch.arange(grid_size[1], device=device, dtype=torch.float32)
108
+ / (grid_size[1] / base_size)
109
+ / interpolation_scale
110
+ )
111
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
112
+ grid = torch.stack(grid, dim=0)
113
+
114
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
115
+ pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
116
+ if cls_token and extra_tokens > 0:
117
+ pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
118
+ return pos_embed
119
+
120
+
121
+ def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
122
+ r"""
123
+ This function generates 2D sinusoidal positional embeddings from a grid.
124
+
125
+ Args:
126
+ embed_dim (`int`): The embedding dimension.
127
+ grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
128
+
129
+ Returns:
130
+ `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
131
+ """
132
+ if output_type == "np":
133
+ deprecation_message = (
134
+ "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
135
+ " `from_numpy` is no longer required."
136
+ " Pass `output_type='pt' to use the new version now."
137
+ )
138
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
139
+ raise ValueError("Not supported")
140
+ if embed_dim % 2 != 0:
141
+ raise ValueError("embed_dim must be divisible by 2")
142
+
143
+ # use half of dimensions to encode grid_h
144
+ emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type) # (H*W, D/2)
145
+ emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type) # (H*W, D/2)
146
+
147
+ emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
148
+ return emb
149
+
150
+
151
+ def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
152
+ """
153
+ This function generates 1D positional embeddings from a grid.
154
+
155
+ Args:
156
+ embed_dim (`int`): The embedding dimension `D`
157
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
158
+
159
+ Returns:
160
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
161
+ """
162
+ if output_type == "np":
163
+ deprecation_message = (
164
+ "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
165
+ " `from_numpy` is no longer required."
166
+ " Pass `output_type='pt' to use the new version now."
167
+ )
168
+ deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
169
+ raise ValueError("Not supported")
170
+ if embed_dim % 2 != 0:
171
+ raise ValueError("embed_dim must be divisible by 2")
172
+
173
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
174
+ omega /= embed_dim / 2.0
175
+ omega = 1.0 / 10000**omega # (D/2,)
176
+
177
+ pos = pos.reshape(-1) + phase # (M,)
178
+ out = torch.outer(pos, omega) # (M, D/2), outer product
179
+
180
+ emb_sin = torch.sin(out) # (M, D/2)
181
+ emb_cos = torch.cos(out) # (M, D/2)
182
+
183
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
184
+ return emb
185
+
186
+
187
+ class PixcellUNIProjection(nn.Module):
188
+ """
189
+ Projects UNI embeddings. Also handles dropout for classifier-free guidance.
190
+
191
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
192
+ """
193
+
194
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
195
+ super().__init__()
196
+ if out_features is None:
197
+ out_features = hidden_size
198
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
199
+ if act_fn == "gelu_tanh":
200
+ self.act_1 = nn.GELU(approximate="tanh")
201
+ elif act_fn == "silu":
202
+ self.act_1 = nn.SiLU()
203
+ elif act_fn == "silu_fp32":
204
+ self.act_1 = FP32SiLU()
205
+ else:
206
+ raise ValueError(f"Unknown activation function: {act_fn}")
207
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
208
+
209
+ self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features ** 0.5))
210
+
211
+ def forward(self, caption):
212
+ hidden_states = self.linear_1(caption)
213
+ hidden_states = self.act_1(hidden_states)
214
+ hidden_states = self.linear_2(hidden_states)
215
+ return hidden_states
216
+
217
+ class UNIPosEmbed(nn.Module):
218
+ """
219
+ Adds positional embeddings to the UNI conditions.
220
+
221
+ Args:
222
+ height (`int`, defaults to `1`):
223
+ Number of UNI embedding tokens along the height of the token grid.
224
+ width (`int`, defaults to `1`):
225
+ Number of UNI embedding tokens along the width of the token grid.
226
+ base_size (`int`, defaults to `16`):
227
+ Base grid size (number of latent patches per side) used to scale the positional frequencies.
228
+ embed_dim (`int`, defaults to `768`):
229
+ The dimension of the UNI embeddings to which the positional embeddings are added.
230
+ interpolation_scale (`float`, defaults to `1`):
231
+ The scale of the interpolation.
232
+ pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding. Only `"sincos"` is supported.
233
+ """
234
+
235
+ def __init__(
236
+ self,
237
+ height=1,
238
+ width=1,
239
+ base_size=16,
240
+ embed_dim=768,
241
+ interpolation_scale=1,
242
+ pos_embed_type="sincos",
243
+ ):
244
+ super().__init__()
245
+
246
+ num_embeds = height*width
247
+ grid_size = int(num_embeds ** 0.5)
248
+
249
+ if pos_embed_type == "sincos":
250
+ y_pos_embed = pixcell_get_2d_sincos_pos_embed(
251
+ embed_dim,
252
+ grid_size,
253
+ base_size=base_size,
254
+ interpolation_scale=interpolation_scale,
255
+ output_type="pt",
256
+ phase = base_size // num_embeds
257
+ )
258
+ self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
259
+ else:
260
+ raise ValueError("`pos_embed_type` not supported")
261
+
262
+ def forward(self, uni_embeds):
263
+ return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
264
+
265
+
266
+
267
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
268
+
269
+
270
+ class PixCellTransformer2DModel(ModelMixin, ConfigMixin):
271
+ r"""
272
+ A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
273
+ https://arxiv.org/abs/2403.04692). Modified for the pathology domain.
274
+
275
+ Parameters:
276
+ num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
277
+ attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
278
+ in_channels (int, defaults to 4): The number of channels in the input.
279
+ out_channels (int, optional):
280
+ The number of channels in the output. Specify this parameter if the output channel number differs from the
281
+ input.
282
+ num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
283
+ dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
284
+ norm_num_groups (int, optional, defaults to 32):
285
+ Number of groups for group normalization within Transformer blocks.
286
+ cross_attention_dim (int, optional):
287
+ The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
288
+ attention_bias (bool, optional, defaults to True):
289
+ Configure if the Transformer blocks' attention should contain a bias parameter.
290
+ sample_size (int, defaults to 128):
291
+ The width of the latent images. This parameter is fixed during training.
292
+ patch_size (int, defaults to 2):
293
+ Size of the patches the model processes, relevant for architectures working on non-sequential data.
294
+ activation_fn (str, optional, defaults to "gelu-approximate"):
295
+ Activation function to use in feed-forward networks within Transformer blocks.
296
+ num_embeds_ada_norm (int, optional, defaults to 1000):
297
+ Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
298
+ inference.
299
+ upcast_attention (bool, optional, defaults to False):
300
+ If true, upcasts the attention mechanism dimensions for potentially improved performance.
301
+ norm_type (str, optional, defaults to "ada_norm_zero"):
302
+ Specifies the type of normalization used, can be 'ada_norm_zero'.
303
+ norm_elementwise_affine (bool, optional, defaults to False):
304
+ If true, enables element-wise affine parameters in the normalization layers.
305
+ norm_eps (float, optional, defaults to 1e-6):
306
+ A small constant added to the denominator in normalization layers to prevent division by zero.
307
+ interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
308
+ use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
309
+ attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
310
+ caption_channels (int, optional, defaults to None):
311
+ Number of channels to use for projecting the caption embeddings.
312
+ use_linear_projection (bool, optional, defaults to False):
313
+ Deprecated argument. Will be removed in a future version.
314
+ num_vector_embeds (bool, optional, defaults to False):
315
+ Deprecated argument. Will be removed in a future version.
316
+ """
317
+
318
+ _supports_gradient_checkpointing = True
319
+ _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
320
+
321
+ @register_to_config
322
+ def __init__(
323
+ self,
324
+ num_attention_heads: int = 16,
325
+ attention_head_dim: int = 72,
326
+ in_channels: int = 4,
327
+ out_channels: Optional[int] = 8,
328
+ num_layers: int = 28,
329
+ dropout: float = 0.0,
330
+ norm_num_groups: int = 32,
331
+ cross_attention_dim: Optional[int] = 1152,
332
+ attention_bias: bool = True,
333
+ sample_size: int = 128,
334
+ patch_size: int = 2,
335
+ activation_fn: str = "gelu-approximate",
336
+ num_embeds_ada_norm: Optional[int] = 1000,
337
+ upcast_attention: bool = False,
338
+ norm_type: str = "ada_norm_single",
339
+ norm_elementwise_affine: bool = False,
340
+ norm_eps: float = 1e-6,
341
+ interpolation_scale: Optional[int] = None,
342
+ use_additional_conditions: Optional[bool] = None,
343
+ caption_channels: Optional[int] = None,
344
+ caption_num_tokens: int = 1,
345
+ attention_type: Optional[str] = "default",
346
+ ):
347
+ super().__init__()
348
+
349
+ # Validate inputs.
350
+ if norm_type != "ada_norm_single":
351
+ raise NotImplementedError(
352
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
353
+ )
354
+ elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
355
+ raise ValueError(
356
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
357
+ )
358
+
359
+ # Set some common variables used across the board.
360
+ self.attention_head_dim = attention_head_dim
361
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
362
+ self.out_channels = in_channels if out_channels is None else out_channels
363
+ if use_additional_conditions is None:
364
+ if sample_size == 128:
365
+ use_additional_conditions = True
366
+ else:
367
+ use_additional_conditions = False
368
+ self.use_additional_conditions = use_additional_conditions
369
+
370
+ self.gradient_checkpointing = False
371
+
372
+ # 2. Initialize the position embedding and transformer blocks.
373
+ self.height = self.config.sample_size
374
+ self.width = self.config.sample_size
375
+
376
+ interpolation_scale = (
377
+ self.config.interpolation_scale
378
+ if self.config.interpolation_scale is not None
379
+ else max(self.config.sample_size // 64, 1)
380
+ )
381
+ self.pos_embed = PatchEmbed(
382
+ height=self.config.sample_size,
383
+ width=self.config.sample_size,
384
+ patch_size=self.config.patch_size,
385
+ in_channels=self.config.in_channels,
386
+ embed_dim=self.inner_dim,
387
+ interpolation_scale=interpolation_scale,
388
+ )
389
+
390
+ self.transformer_blocks = nn.ModuleList(
391
+ [
392
+ BasicTransformerBlock(
393
+ self.inner_dim,
394
+ self.config.num_attention_heads,
395
+ self.config.attention_head_dim,
396
+ dropout=self.config.dropout,
397
+ cross_attention_dim=self.config.cross_attention_dim,
398
+ activation_fn=self.config.activation_fn,
399
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
400
+ attention_bias=self.config.attention_bias,
401
+ upcast_attention=self.config.upcast_attention,
402
+ norm_type=norm_type,
403
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
404
+ norm_eps=self.config.norm_eps,
405
+ attention_type=self.config.attention_type,
406
+ )
407
+ for _ in range(self.config.num_layers)
408
+ ]
409
+ )
410
+
411
+ # Initialize the positional embedding for the conditions for >1 UNI embeddings
412
+ if self.config.caption_num_tokens == 1:
413
+ self.y_pos_embed = None
414
+ else:
415
+ # 1:1 aspect ratio
416
+ self.uni_height = int(self.config.caption_num_tokens ** 0.5)
417
+ self.uni_width = int(self.config.caption_num_tokens ** 0.5)
418
+
419
+ self.y_pos_embed = UNIPosEmbed(
420
+ height=self.uni_height,
421
+ width=self.uni_width,
422
+ base_size=self.config.sample_size // self.config.patch_size,
423
+ embed_dim=self.config.caption_channels,
424
+ interpolation_scale=2, # Should this be fixed?
425
+ pos_embed_type="sincos", # This is fixed
426
+ )
427
+
428
+ # 3. Output blocks.
429
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
430
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
431
+ self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
432
+
433
+ self.adaln_single = AdaLayerNormSingle(
434
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
435
+ )
436
+ self.caption_projection = None
437
+ if self.config.caption_channels is not None:
438
+ self.caption_projection = PixcellUNIProjection(
439
+ in_features=self.config.caption_channels, hidden_size=self.inner_dim, num_tokens=self.config.caption_num_tokens,
440
+ )
441
+
442
+ def _set_gradient_checkpointing(self, module, value=False):
443
+ if hasattr(module, "gradient_checkpointing"):
444
+ module.gradient_checkpointing = value
445
+
446
+ @property
447
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
448
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
449
+ r"""
450
+ Returns:
451
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
452
+ indexed by its weight name.
453
+ """
454
+ # set recursively
455
+ processors = {}
456
+
457
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
458
+ if hasattr(module, "get_processor"):
459
+ processors[f"{name}.processor"] = module.get_processor()
460
+
461
+ for sub_name, child in module.named_children():
462
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
463
+
464
+ return processors
465
+
466
+ for name, module in self.named_children():
467
+ fn_recursive_add_processors(name, module, processors)
468
+
469
+ return processors
470
+
471
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
472
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
473
+ r"""
474
+ Sets the attention processor to use to compute attention.
475
+
476
+ Parameters:
477
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
478
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
479
+ for **all** `Attention` layers.
480
+
481
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
482
+ processor. This is strongly recommended when setting trainable attention processors.
483
+
484
+ """
485
+ count = len(self.attn_processors.keys())
486
+
487
+ if isinstance(processor, dict) and len(processor) != count:
488
+ raise ValueError(
489
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
490
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
491
+ )
492
+
493
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
494
+ if hasattr(module, "set_processor"):
495
+ if not isinstance(processor, dict):
496
+ module.set_processor(processor)
497
+ else:
498
+ module.set_processor(processor.pop(f"{name}.processor"))
499
+
500
+ for sub_name, child in module.named_children():
501
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
502
+
503
+ for name, module in self.named_children():
504
+ fn_recursive_attn_processor(name, module, processor)
505
+
506
+ def set_default_attn_processor(self):
507
+ """
508
+ Disables custom attention processors and sets the default attention implementation.
509
+
510
+ Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
511
+ """
512
+ self.set_attn_processor(AttnProcessor())
513
+
514
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
515
+ def fuse_qkv_projections(self):
516
+ """
517
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
518
+ are fused. For cross-attention modules, key and value projection matrices are fused.
519
+
520
+ <Tip warning={true}>
521
+
522
+ This API is 🧪 experimental.
523
+
524
+ </Tip>
525
+ """
526
+ self.original_attn_processors = None
527
+
528
+ for _, attn_processor in self.attn_processors.items():
529
+ if "Added" in str(attn_processor.__class__.__name__):
530
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
531
+
532
+ self.original_attn_processors = self.attn_processors
533
+
534
+ for module in self.modules():
535
+ if isinstance(module, Attention):
536
+ module.fuse_projections(fuse=True)
537
+
538
+ self.set_attn_processor(FusedAttnProcessor2_0())
539
+
540
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
541
+ def unfuse_qkv_projections(self):
542
+ """Disables the fused QKV projection if enabled.
543
+
544
+ <Tip warning={true}>
545
+
546
+ This API is 🧪 experimental.
547
+
548
+ </Tip>
549
+
550
+ """
551
+ if self.original_attn_processors is not None:
552
+ self.set_attn_processor(self.original_attn_processors)
553
+
554
+ def forward(
555
+ self,
556
+ hidden_states: torch.Tensor,
557
+ encoder_hidden_states: Optional[torch.Tensor] = None,
558
+ timestep: Optional[torch.LongTensor] = None,
559
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
560
+ cross_attention_kwargs: Dict[str, Any] = None,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ encoder_attention_mask: Optional[torch.Tensor] = None,
563
+ return_dict: bool = True,
564
+ ):
565
+ """
566
+ The [`PixCellTransformer2DModel`] forward method.
567
+
568
+ Args:
569
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
570
+ Input `hidden_states`.
571
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
572
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
573
+ self-attention.
574
+ timestep (`torch.LongTensor`, *optional*):
575
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
576
+ added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
577
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
578
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
579
+ `self.processor` in
580
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
581
+ attention_mask ( `torch.Tensor`, *optional*):
582
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
583
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
584
+ negative values to the attention scores corresponding to "discard" tokens.
585
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
586
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
587
+
588
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
589
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
590
+
591
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
592
+ above. This bias will be added to the cross-attention scores.
593
+ return_dict (`bool`, *optional*, defaults to `True`):
594
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
595
+ tuple.
596
+
597
+ Returns:
598
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
599
+ `tuple` where the first element is the sample tensor.
600
+ """
601
+ if self.use_additional_conditions and added_cond_kwargs is None:
602
+ raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
603
+
604
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
605
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
606
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
607
+ # expects mask of shape:
608
+ # [batch, key_tokens]
609
+ # adds singleton query_tokens dimension:
610
+ # [batch, 1, key_tokens]
611
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
612
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
613
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
614
+ if attention_mask is not None and attention_mask.ndim == 2:
615
+ # assume that mask is expressed as:
616
+ # (1 = keep, 0 = discard)
617
+ # convert mask into a bias that can be added to attention scores:
618
+ # (keep = +0, discard = -10000.0)
619
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
620
+ attention_mask = attention_mask.unsqueeze(1)
621
+
622
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
623
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
624
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
625
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
626
+
627
+ # 1. Input
628
+ batch_size = hidden_states.shape[0]
629
+ height, width = (
630
+ hidden_states.shape[-2] // self.config.patch_size,
631
+ hidden_states.shape[-1] // self.config.patch_size,
632
+ )
633
+ hidden_states = self.pos_embed(hidden_states)
634
+
635
+ timestep, embedded_timestep = self.adaln_single(
636
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
637
+ )
638
+
639
+ if self.caption_projection is not None:
640
+ # Add positional embeddings to conditions if >1 UNI are given
641
+ if self.y_pos_embed is not None:
642
+ encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
643
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
644
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
645
+
646
+ # 2. Blocks
647
+ for block in self.transformer_blocks:
648
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
649
+
650
+ def create_custom_forward(module, return_dict=None):
651
+ def custom_forward(*inputs):
652
+ if return_dict is not None:
653
+ return module(*inputs, return_dict=return_dict)
654
+ else:
655
+ return module(*inputs)
656
+
657
+ return custom_forward
658
+
659
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
660
+ hidden_states = torch.utils.checkpoint.checkpoint(
661
+ create_custom_forward(block),
662
+ hidden_states,
663
+ attention_mask,
664
+ encoder_hidden_states,
665
+ encoder_attention_mask,
666
+ timestep,
667
+ cross_attention_kwargs,
668
+ None,
669
+ **ckpt_kwargs,
670
+ )
671
+ else:
672
+ hidden_states = block(
673
+ hidden_states,
674
+ attention_mask=attention_mask,
675
+ encoder_hidden_states=encoder_hidden_states,
676
+ encoder_attention_mask=encoder_attention_mask,
677
+ timestep=timestep,
678
+ cross_attention_kwargs=cross_attention_kwargs,
679
+ class_labels=None,
680
+ )
681
+
682
+ # 3. Output
683
+ shift, scale = (
684
+ self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
685
+ ).chunk(2, dim=1)
686
+ hidden_states = self.norm_out(hidden_states)
687
+ # Modulation
688
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
689
+ hidden_states = self.proj_out(hidden_states)
690
+ hidden_states = hidden_states.squeeze(1)
691
+
692
+ # unpatchify
693
+ hidden_states = hidden_states.reshape(
694
+ shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
695
+ )
696
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
697
+ output = hidden_states.reshape(
698
+ shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
699
+ )
700
+
701
+ if not return_dict:
702
+ return (output,)
703
+
704
+ return Transformer2DModelOutput(sample=output)
705
+
706
+
707
+
708
+
709
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
710
+
711
+
712
+ EXAMPLE_DOC_STRING = """
713
+ Examples:
714
+ ```py
715
+ >>> import torch
716
+ >>> from pipeline import PixCellPipeline
717
+
718
+ >>> # Load from a local clone or a Hub repository id ("path/to/pixcell-checkpoint" is a placeholder).
719
+ >>> pipe = PixCellPipeline.from_pretrained(
720
+ ... "path/to/pixcell-checkpoint", torch_dtype=torch.float16
721
+ ... )
722
+ >>> # Enable memory optimizations.
723
+ >>> # pipe.enable_model_cpu_offload()
724
+
725
+ >>> # `uni_embeds` is a (batch, num_tokens, dim) tensor of UNI embeddings for the reference patch.
726
+ >>> image = pipe(uni_embeds=uni_embeds, guidance_scale=1.0).images[0]
727
+ ```
728
+ """
729
+
730
+
731
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
732
+ def retrieve_timesteps(
733
+ scheduler,
734
+ num_inference_steps: Optional[int] = None,
735
+ device: Optional[Union[str, torch.device]] = None,
736
+ timesteps: Optional[List[int]] = None,
737
+ sigmas: Optional[List[float]] = None,
738
+ **kwargs,
739
+ ):
740
+ r"""
741
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
742
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
743
+
744
+ Args:
745
+ scheduler (`SchedulerMixin`):
746
+ The scheduler to get timesteps from.
747
+ num_inference_steps (`int`):
748
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
749
+ must be `None`.
750
+ device (`str` or `torch.device`, *optional*):
751
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
752
+ timesteps (`List[int]`, *optional*):
753
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
754
+ `num_inference_steps` and `sigmas` must be `None`.
755
+ sigmas (`List[float]`, *optional*):
756
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
757
+ `num_inference_steps` and `timesteps` must be `None`.
758
+
759
+ Returns:
760
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
761
+ second element is the number of inference steps.
762
+ """
763
+ if timesteps is not None and sigmas is not None:
764
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
765
+ if timesteps is not None:
766
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
767
+ if not accepts_timesteps:
768
+ raise ValueError(
769
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
770
+ f" timestep schedules. Please check whether you are using the correct scheduler."
771
+ )
772
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
773
+ timesteps = scheduler.timesteps
774
+ num_inference_steps = len(timesteps)
775
+ elif sigmas is not None:
776
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
777
+ if not accept_sigmas:
778
+ raise ValueError(
779
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
780
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
781
+ )
782
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
783
+ timesteps = scheduler.timesteps
784
+ num_inference_steps = len(timesteps)
785
+ else:
786
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
787
+ timesteps = scheduler.timesteps
788
+ return timesteps, num_inference_steps
789
+
790
+
791
+ class PixCellPipeline(DiffusionPipeline):
792
+ r"""
793
+ Pipeline for SSL-to-image generation using PixCell.
794
+ """
795
+
796
+ model_cpu_offload_seq = "transformer->vae"
797
+
798
+ def __init__(
799
+ self,
800
+ vae: AutoencoderKL,
801
+ transformer: PixCellTransformer2DModel,
802
+ scheduler: DPMSolverMultistepScheduler,
803
+ ):
804
+ super().__init__()
805
+
806
+ self.register_modules(
807
+ vae=vae, transformer=transformer, scheduler=scheduler
808
+ )
809
+
810
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
811
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
812
+
813
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
814
+ def prepare_extra_step_kwargs(self, generator, eta):
815
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
816
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
817
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
818
+ # and should be between [0, 1]
819
+
820
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
821
+ extra_step_kwargs = {}
822
+ if accepts_eta:
823
+ extra_step_kwargs["eta"] = eta
824
+
825
+ # check if the scheduler accepts generator
826
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
827
+ if accepts_generator:
828
+ extra_step_kwargs["generator"] = generator
829
+ return extra_step_kwargs
830
+
831
+ def get_unconditional_embedding(self, batch_size=1):
832
+ # Unconditional embedding is learned
833
+ uncond = self.transformer.caption_projection.uncond_embedding.clone().tile(batch_size,1,1)
834
+ return uncond
835
+
836
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
837
+ def check_inputs(
838
+ self,
839
+ height,
840
+ width,
841
+ callback_steps,
842
+ uni_embeds=None,
843
+ negative_uni_embeds=None,
844
+ guidance_scale=None,
845
+ ):
846
+ if height % 8 != 0 or width % 8 != 0:
847
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
848
+
849
+ if (callback_steps is None) or (
850
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
851
+ ):
852
+ raise ValueError(
853
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
854
+ f" {type(callback_steps)}."
855
+ )
856
+
857
+ if uni_embeds is None:
858
+ raise ValueError(
859
+ "Provide a UNI embedding `uni_embeds`."
860
+ )
861
+ elif len(uni_embeds.shape) != 3:
862
+ raise ValueError(
863
+ "UNI embedding given is not in (B,N,D)."
864
+ )
865
+ elif uni_embeds.shape[1] != self.transformer.config.caption_num_tokens:
866
+ raise ValueError(
867
+ f"Number of UNI embeddings must match the ones used in training ({self.transformer.config.caption_num_tokens})."
868
+ )
869
+ elif uni_embeds.shape[2] != self.transformer.config.caption_channels:
870
+ raise ValueError(
871
+ "UNI embedding given has incorrect dimenions."
872
+ )
873
+
874
+ if guidance_scale > 1.0:
875
+ if negative_uni_embeds is None:
876
+ raise ValueError(
877
+ "Provide a negative UNI embedding `negative_uni_embeds`."
878
+ )
879
+ elif len(negative_uni_embeds.shape) != 3:
880
+ raise ValueError(
881
+ "Negative UNI embedding given is not in (B,N,D)."
882
+ )
883
+ elif negative_uni_embeds.shape[1] != self.transformer.config.caption_num_tokens:
884
+ raise ValueError(
885
+ f"Number of negative UNI embeddings must match the ones used in training ({self.transformer.config.caption_num_tokens})."
886
+ )
887
+ elif negative_uni_embeds.shape[2] != self.transformer.config.caption_channels:
888
+ raise ValueError(
889
+ "Negative UNI embedding given has incorrect dimenions."
890
+ )
891
+
892
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
893
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
894
+ shape = (
895
+ batch_size,
896
+ num_channels_latents,
897
+ int(height) // self.vae_scale_factor,
898
+ int(width) // self.vae_scale_factor,
899
+ )
900
+ if isinstance(generator, list) and len(generator) != batch_size:
901
+ raise ValueError(
902
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
903
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
904
+ )
905
+
906
+ if latents is None:
907
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
908
+ else:
909
+ latents = latents.to(device)
910
+
911
+ # scale the initial noise by the standard deviation required by the scheduler
912
+ latents = latents * self.scheduler.init_noise_sigma
913
+ return latents
914
+
915
+ @torch.no_grad()
916
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
917
+ def __call__(
918
+ self,
919
+ num_inference_steps: int = 20,
920
+ timesteps: List[int] = None,
921
+ sigmas: List[float] = None,
922
+ guidance_scale: float = 1.5,
923
+ num_images_per_prompt: Optional[int] = 1,
924
+ height: Optional[int] = None,
925
+ width: Optional[int] = None,
926
+ eta: float = 0.0,
927
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
928
+ latents: Optional[torch.Tensor] = None,
929
+ uni_embeds: Optional[torch.Tensor] = None,
930
+ negative_uni_embeds: Optional[torch.Tensor] = None,
931
+ output_type: Optional[str] = "pil",
932
+ return_dict: bool = True,
933
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
934
+ callback_steps: int = 1,
935
+ **kwargs,
936
+ ) -> Union[ImagePipelineOutput, Tuple]:
937
+ """
938
+ Function invoked when calling the pipeline for generation.
939
+
940
+ Args:
941
+ num_inference_steps (`int`, *optional*, defaults to 20):
942
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
943
+ expense of slower inference.
944
+ timesteps (`List[int]`, *optional*):
945
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
946
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
947
+ passed will be used. Must be in descending order.
948
+ sigmas (`List[float]`, *optional*):
949
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
950
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
951
+ will be used.
952
+ guidance_scale (`float`, *optional*, defaults to 1.5):
953
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
954
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
955
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
956
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
957
+ usually at the expense of lower image quality.
958
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
959
+ The number of images to generate per prompt.
960
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
961
+ The height in pixels of the generated image.
962
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
963
+ The width in pixels of the generated image.
964
+ eta (`float`, *optional*, defaults to 0.0):
965
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
966
+ [`schedulers.DDIMScheduler`], will be ignored for others.
967
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
968
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
969
+ to make generation deterministic.
970
+ latents (`torch.Tensor`, *optional*):
971
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
972
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
973
+ tensor will be generated by sampling using the supplied random `generator`.
974
+ uni_embeds (`torch.Tensor`, *optional*):
975
+ Pre-generated UNI embeddings.
976
+ negative_uni_embeds (`torch.Tensor`, *optional*):
977
+ Pre-generated negative UNI embeddings.
978
+ output_type (`str`, *optional*, defaults to `"pil"`):
979
+ The output format of the generate image. Choose between
980
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
981
+ return_dict (`bool`, *optional*, defaults to `True`):
982
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
983
+ callback (`Callable`, *optional*):
984
+ A function that will be called every `callback_steps` steps during inference. The function will be
985
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
986
+ callback_steps (`int`, *optional*, defaults to 1):
987
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
988
+ called at every step.
989
+
990
+ Examples:
991
+
992
+ Returns:
993
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
994
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
995
+ returned where the first element is a list with the generated images
996
+ """
997
+ # 1. Check inputs. Raise error if not correct
998
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
999
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
1000
+
1001
+ self.check_inputs(
1002
+ height,
1003
+ width,
1004
+ callback_steps,
1005
+ uni_embeds,
1006
+ negative_uni_embeds,
1007
+ guidance_scale,
1008
+ )
1009
+
1010
+ # 2. Determine the batch size from the provided UNI embeddings
1011
+ batch_size = uni_embeds.shape[0]
1012
+
1013
+ device = self._execution_device
1014
+
1015
+ # 3. Handle UNI conditioning
1016
+
1017
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1018
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1019
+ # corresponds to doing no classifier free guidance.
1020
+ do_classifier_free_guidance = guidance_scale > 1.0
1021
+
1022
+ uni_embeds = uni_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1023
+ if do_classifier_free_guidance:
1024
+ negative_uni_embeds = negative_uni_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1025
+ uni_embeds = torch.cat([negative_uni_embeds, uni_embeds], dim=0)
1026
+
1027
+
1028
+ # 4. Prepare timesteps
1029
+ timesteps, num_inference_steps = retrieve_timesteps(
1030
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1031
+ )
1032
+
1033
+ # 5. Prepare latents.
1034
+ latent_channels = self.transformer.config.in_channels
1035
+ latents = self.prepare_latents(
1036
+ batch_size * num_images_per_prompt,
1037
+ latent_channels,
1038
+ height,
1039
+ width,
1040
+ uni_embeds.dtype,
1041
+ device,
1042
+ generator,
1043
+ latents,
1044
+ )
1045
+
1046
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1047
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1048
+
1049
+ added_cond_kwargs = {}
1050
+
1051
+ # 7. Denoising loop
1052
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1053
+
1054
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1055
+ for i, t in enumerate(timesteps):
1056
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1057
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1058
+
1059
+ current_timestep = t
1060
+ if not torch.is_tensor(current_timestep):
1061
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1062
+ # This would be a good case for the `match` statement (Python 3.10+)
1063
+ is_mps = latent_model_input.device.type == "mps"
1064
+ if isinstance(current_timestep, float):
1065
+ dtype = torch.float32 if is_mps else torch.float64
1066
+ else:
1067
+ dtype = torch.int32 if is_mps else torch.int64
1068
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
1069
+ elif len(current_timestep.shape) == 0:
1070
+ current_timestep = current_timestep[None].to(latent_model_input.device)
1071
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1072
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
1073
+
1074
+ # predict noise model_output
1075
+ noise_pred = self.transformer(
1076
+ latent_model_input,
1077
+ encoder_hidden_states=uni_embeds,
1078
+ timestep=current_timestep,
1079
+ added_cond_kwargs=added_cond_kwargs,
1080
+ return_dict=False,
1081
+ )[0]
1082
+
1083
+ # perform guidance
1084
+ if do_classifier_free_guidance:
1085
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1086
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1087
+
1088
+ # learned sigma
1089
+ if self.transformer.config.out_channels // 2 == latent_channels:
1090
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
1091
+ else:
1092
+ noise_pred = noise_pred
1093
+
1094
+ # compute previous image: x_t -> x_t-1
1095
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1096
+
1097
+ # call the callback, if provided
1098
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1099
+ progress_bar.update()
1100
+ if callback is not None and i % callback_steps == 0:
1101
+ step_idx = i // getattr(self.scheduler, "order", 1)
1102
+ callback(step_idx, t, latents)
1103
+
1104
+ if not output_type == "latent":
1105
+ vae_scale = self.vae.config.scaling_factor
1106
+ vae_shift = getattr(self.vae.config, "shift_factor", 0)
1107
+
1108
+ image = self.vae.decode((latents / vae_scale) + vae_shift, return_dict=False)[0]
1109
+
1110
+ else:
1111
+ image = latents
1112
+
1113
+ if not output_type == "latent":
1114
+ image = self.image_processor.postprocess(image, output_type=output_type)
1115
+
1116
+ # Offload all models
1117
+ self.maybe_free_model_hooks()
1118
+
1119
+ if not return_dict:
1120
+ return (image,)
1121
+
1122
+ return ImagePipelineOutput(images=image)
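
For reference, a minimal inference sketch against the `__call__` signature above. The UNI embedding here is random noise purely for illustration; in practice it would come from a UNI feature extractor applied to a reference pathology patch (an assumption about the intended workflow), and `pipe` is a loaded `PixCellPipeline` as in the earlier sketch.

```python
# Hedged sketch: generate one image from a (batch, num_tokens, dim) UNI conditioning tensor.
import torch

batch = 1
num_tokens = pipe.transformer.config.caption_num_tokens  # 1 per transformer/config.json
dim = pipe.transformer.config.caption_channels           # 1536 per transformer/config.json

# Placeholder conditioning; a real UNI embedding of a reference patch would go here.
uni_embeds = torch.randn(batch, num_tokens, dim, device="cuda", dtype=torch.float16)

# The unconditional embedding is learned and stored inside the caption projection layer.
negative = pipe.get_unconditional_embedding(batch_size=batch).to("cuda", torch.float16)

image = pipe(
    uni_embeds=uni_embeds,
    negative_uni_embeds=negative,
    guidance_scale=1.5,        # > 1.0 enables classifier-free guidance, so a negative embedding is required
    num_inference_steps=20,
).images[0]
image.save("pixcell_sample.png")
```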
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "_class_name": "DPMSolverMultistepScheduler",
3
+ "_diffusers_version": "0.32.2",
4
+ "algorithm_type": "dpmsolver++",
5
+ "beta_end": 0.02,
6
+ "beta_schedule": "linear",
7
+ "beta_start": 0.0001,
8
+ "dynamic_thresholding_ratio": 0.995,
9
+ "euler_at_final": false,
10
+ "final_sigmas_type": "zero",
11
+ "flow_shift": 1.0,
12
+ "lambda_min_clipped": -Infinity,
13
+ "lower_order_final": true,
14
+ "num_train_timesteps": 1000,
15
+ "prediction_type": "epsilon",
16
+ "rescale_betas_zero_snr": false,
17
+ "sample_max_value": 1.0,
18
+ "solver_order": 2,
19
+ "solver_type": "midpoint",
20
+ "steps_offset": 0,
21
+ "thresholding": false,
22
+ "timestep_spacing": "linspace",
23
+ "trained_betas": null,
24
+ "use_beta_sigmas": false,
25
+ "use_exponential_sigmas": false,
26
+ "use_flow_sigmas": false,
27
+ "use_karras_sigmas": false,
28
+ "use_lu_lambdas": false,
29
+ "variance_type": null
30
+ }
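
The scheduler is a stock diffusers `DPMSolverMultistepScheduler` (dpmsolver++ with epsilon prediction); `retrieve_timesteps` in pipeline.py simply calls its `set_timesteps` and reads back the schedule. A small sketch, assuming it is run from a local clone of this repository:

```python
# Hedged sketch: load the scheduler config above and inspect the timestep schedule
# that the denoising loop in pipeline.py iterates over.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler.from_pretrained(".", subfolder="scheduler")
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.timesteps)  # 20 descending timesteps drawn from num_train_timesteps=1000
```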
test_image.jpg ADDED

Git LFS Details

  • SHA256: a0d3425c9212ca34a6c7478d392c09613c7350b3c52aa9c489ef72288dd55e76
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
transformer/config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "_class_name": "PixCellTransformer2DModel",
3
+ "_diffusers_version": "0.32.2",
4
+ "_name_or_path": "pixart_256/transformer",
5
+ "activation_fn": "gelu-approximate",
6
+ "attention_bias": true,
7
+ "attention_head_dim": 72,
8
+ "attention_type": "default",
9
+ "caption_channels": 1536,
10
+ "caption_num_tokens": 1,
11
+ "cross_attention_dim": 1152,
12
+ "double_self_attention": false,
13
+ "dropout": 0.0,
14
+ "in_channels": 16,
15
+ "interpolation_scale": 0.5,
16
+ "norm_elementwise_affine": false,
17
+ "norm_eps": 1e-06,
18
+ "norm_num_groups": 32,
19
+ "norm_type": "ada_norm_single",
20
+ "num_attention_heads": 16,
21
+ "num_embeds_ada_norm": 1000,
22
+ "num_layers": 28,
23
+ "num_vector_embeds": null,
24
+ "only_cross_attention": false,
25
+ "out_channels": 32,
26
+ "patch_size": 2,
27
+ "sample_size": 32,
28
+ "upcast_attention": false,
29
+ "use_additional_conditions": false,
30
+ "use_linear_projection": false
31
+ }
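
This config pins the transformer to a 32x32 latent grid with patch_size 2, 16 latent input channels and 32 output channels (noise plus a learned sigma, which the pipeline splits off), conditioned on a single 1536-dimensional UNI token. A quick shape-check sketch, assuming a local clone with the weights downloaded and `transformer/` on the Python import path:

```python
# Hedged sketch: instantiate the transformer from this config/weights and run a dummy forward pass.
import torch
from pixcell_transformer_2d import PixCellTransformer2DModel  # transformer/pixcell_transformer_2d.py

model = PixCellTransformer2DModel.from_pretrained(".", subfolder="transformer", torch_dtype=torch.float32)
latents = torch.randn(1, 16, 32, 32)   # in_channels=16, sample_size=32
uni = torch.randn(1, 1, 1536)          # caption_num_tokens=1, caption_channels=1536
timestep = torch.tensor([999])

# added_cond_kwargs={} mirrors what PixCellPipeline passes (use_additional_conditions is false).
out = model(latents, encoder_hidden_states=uni, timestep=timestep,
            added_cond_kwargs={}, return_dict=False)[0]
print(out.shape)  # torch.Size([1, 32, 32, 32]); out_channels=32 = noise prediction + learned sigma
```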
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2897bd89e90037faf71c16ba6ce87165cb7c8509cafdab6aa824c3ea827cbe8
3
+ size 2432366184
transformer/pixcell_transformer_2d.py ADDED
@@ -0,0 +1,676 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Union
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.utils import is_torch_version, logging
21
+ from diffusers.models.attention import BasicTransformerBlock
22
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
23
+ from diffusers.models.embeddings import PatchEmbed
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormSingle
27
+ from diffusers.models.activations import deprecate, FP32SiLU
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ # PixCell UNI conditioning
34
+ def pixcell_get_2d_sincos_pos_embed(
35
+ embed_dim,
36
+ grid_size,
37
+ cls_token=False,
38
+ extra_tokens=0,
39
+ interpolation_scale=1.0,
40
+ base_size=16,
41
+ device: Optional[torch.device] = None,
42
+ phase=0,
43
+ output_type: str = "np",
44
+ ):
45
+ """
46
+ Creates 2D sinusoidal positional embeddings.
47
+
48
+ Args:
49
+ embed_dim (`int`):
50
+ The embedding dimension.
51
+ grid_size (`int`):
52
+ The size of the grid height and width.
53
+ cls_token (`bool`, defaults to `False`):
54
+ Whether or not to add a classification token.
55
+ extra_tokens (`int`, defaults to `0`):
56
+ The number of extra tokens to add.
57
+ interpolation_scale (`float`, defaults to `1.0`):
58
+ The scale of the interpolation.
59
+
60
+ Returns:
61
+ pos_embed (`torch.Tensor`):
62
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
63
+ embed_dim]` if using cls_token
64
+ """
65
+ if output_type == "np":
66
+ deprecation_message = (
67
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
68
+ " `from_numpy` is no longer required."
69
+ " Pass `output_type='pt' to use the new version now."
70
+ )
71
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
72
+ raise ValueError("output_type='np' is not supported; pass output_type='pt' instead.")
73
+ if isinstance(grid_size, int):
74
+ grid_size = (grid_size, grid_size)
75
+
76
+ grid_h = (
77
+ torch.arange(grid_size[0], device=device, dtype=torch.float32)
78
+ / (grid_size[0] / base_size)
79
+ / interpolation_scale
80
+ )
81
+ grid_w = (
82
+ torch.arange(grid_size[1], device=device, dtype=torch.float32)
83
+ / (grid_size[1] / base_size)
84
+ / interpolation_scale
85
+ )
86
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
87
+ grid = torch.stack(grid, dim=0)
88
+
89
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
90
+ pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
91
+ if cls_token and extra_tokens > 0:
92
+ pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
93
+ return pos_embed
94
+
95
+
96
+ def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
97
+ r"""
98
+ This function generates 2D sinusoidal positional embeddings from a grid.
99
+
100
+ Args:
101
+ embed_dim (`int`): The embedding dimension.
102
+ grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
103
+
104
+ Returns:
105
+ `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
106
+ """
107
+ if output_type == "np":
108
+ deprecation_message = (
109
+ "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
110
+ " `from_numpy` is no longer required."
111
+ " Pass `output_type='pt' to use the new version now."
112
+ )
113
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
114
+ raise ValueError("output_type='np' is not supported; pass output_type='pt' instead.")
115
+ if embed_dim % 2 != 0:
116
+ raise ValueError("embed_dim must be divisible by 2")
117
+
118
+ # use half of dimensions to encode grid_h
119
+ emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type) # (H*W, D/2)
120
+ emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type) # (H*W, D/2)
121
+
122
+ emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
123
+ return emb
124
+
125
+
126
+ def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
127
+ """
128
+ This function generates 1D positional embeddings from a grid.
129
+
130
+ Args:
131
+ embed_dim (`int`): The embedding dimension `D`
132
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
133
+
134
+ Returns:
135
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
136
+ """
137
+ if output_type == "np":
138
+ deprecation_message = (
139
+ "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
140
+ " `from_numpy` is no longer required."
141
+ " Pass `output_type='pt' to use the new version now."
142
+ )
143
+ deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
144
+ raise ValueError("output_type='np' is not supported; pass output_type='pt' instead.")
145
+ if embed_dim % 2 != 0:
146
+ raise ValueError("embed_dim must be divisible by 2")
147
+
148
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
149
+ omega /= embed_dim / 2.0
150
+ omega = 1.0 / 10000**omega # (D/2,)
151
+
152
+ pos = pos.reshape(-1) + phase # (M,)
153
+ out = torch.outer(pos, omega) # (M, D/2), outer product
154
+
155
+ emb_sin = torch.sin(out) # (M, D/2)
156
+ emb_cos = torch.cos(out) # (M, D/2)
157
+
158
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
159
+ return emb
160
+
161
+
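The three helpers above build the standard 2D sin-cos table, with an extra `phase` offset that shifts the grid when a patch of UNI tokens is embedded. A minimal shape check, assuming this file is importable as `pixcell_transformer_2d` (an assumption about the import path):

from pixcell_transformer_2d import pixcell_get_2d_sincos_pos_embed

# 2x2 grid of 1536-dim embeddings (matching caption_channels in the config above).
pos = pixcell_get_2d_sincos_pos_embed(
    embed_dim=1536,
    grid_size=2,
    base_size=16,            # sample_size // patch_size
    interpolation_scale=2,
    phase=4,                 # base_size // num_embeds, as computed in UNIPosEmbed below
    output_type="pt",
)
print(pos.shape)             # torch.Size([4, 1536])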
162
+ class PixcellUNIProjection(nn.Module):
163
+ """
164
+ Projects UNI embeddings. Also handles dropout for classifier-free guidance.
165
+
166
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
167
+ """
168
+
169
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
170
+ super().__init__()
171
+ if out_features is None:
172
+ out_features = hidden_size
173
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
174
+ if act_fn == "gelu_tanh":
175
+ self.act_1 = nn.GELU(approximate="tanh")
176
+ elif act_fn == "silu":
177
+ self.act_1 = nn.SiLU()
178
+ elif act_fn == "silu_fp32":
179
+ self.act_1 = FP32SiLU()
180
+ else:
181
+ raise ValueError(f"Unknown activation function: {act_fn}")
182
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
183
+
184
+ self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features ** 0.5))  # null UNI embedding used for classifier-free guidance
185
+
186
+ def forward(self, caption):
187
+ hidden_states = self.linear_1(caption)
188
+ hidden_states = self.act_1(hidden_states)
189
+ hidden_states = self.linear_2(hidden_states)
190
+ return hidden_states
191
+
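`PixcellUNIProjection` projects the 1536-dim UNI embeddings into the transformer's 1152-dim hidden size and registers an `uncond_embedding` buffer that stands in for the UNI condition under classifier-free guidance. A small usage sketch under the same import assumption (shapes are illustrative):

import torch
from pixcell_transformer_2d import PixcellUNIProjection

proj = PixcellUNIProjection(in_features=1536, hidden_size=1152, num_tokens=1)
uni = torch.randn(2, 1, 1536)                        # (batch, num_tokens, in_features)
cond = proj(uni)                                     # -> (2, 1, 1152)
uncond = proj.uncond_embedding.expand(2, -1, -1)     # null embedding used in place of UNI
print(cond.shape, proj(uncond).shape)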
192
+ class UNIPosEmbed(nn.Module):
193
+ """
194
+ Adds positional embeddings to the UNI conditions.
195
+
196
+ Args:
197
+ height (`int`, defaults to `1`): Height of the grid of UNI tokens.
+ width (`int`, defaults to `1`): Width of the grid of UNI tokens.
+ base_size (`int`, defaults to `16`): Base grid size (`sample_size // patch_size`) used to scale the positions.
+ embed_dim (`int`, defaults to `768`): The dimension of the UNI embeddings.
+ interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+ pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding; only `"sincos"` is supported.
208
+ """
209
+
210
+ def __init__(
211
+ self,
212
+ height=1,
213
+ width=1,
214
+ base_size=16,
215
+ embed_dim=768,
216
+ interpolation_scale=1,
217
+ pos_embed_type="sincos",
218
+ ):
219
+ super().__init__()
220
+
221
+ num_embeds = height*width
222
+ grid_size = int(num_embeds ** 0.5)
223
+
224
+ if pos_embed_type == "sincos":
225
+ y_pos_embed = pixcell_get_2d_sincos_pos_embed(
226
+ embed_dim,
227
+ grid_size,
228
+ base_size=base_size,
229
+ interpolation_scale=interpolation_scale,
230
+ output_type="pt",
231
+ phase = base_size // num_embeds
232
+ )
233
+ self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
234
+ else:
235
+ raise ValueError("`pos_embed_type` not supported")
236
+
237
+ def forward(self, uni_embeds):
238
+ return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
239
+
240
+
241
+
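`UNIPosEmbed` simply adds the fixed sin-cos table to a square grid of UNI tokens; only `pos_embed_type="sincos"` is supported, and the model only instantiates it when `caption_num_tokens > 1`. A shape sketch with illustrative values:

import torch
from pixcell_transformer_2d import UNIPosEmbed

pos_embed = UNIPosEmbed(height=2, width=2, base_size=16, embed_dim=1536, interpolation_scale=2)
uni_grid = torch.randn(1, 4, 1536)    # (batch, height * width, embed_dim)
print(pos_embed(uni_grid).shape)      # torch.Size([1, 4, 1536])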
242
+ class PixCellTransformer2DModel(ModelMixin, ConfigMixin):
243
+ r"""
244
+ A 2D Transformer model as introduced in the PixArt family of models (https://arxiv.org/abs/2310.00426,
245
+ https://arxiv.org/abs/2403.04692). Modified for the pathology domain.
246
+
247
+ Parameters:
248
+ num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
249
+ attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
250
+ in_channels (int, defaults to 4): The number of channels in the input.
251
+ out_channels (int, optional):
252
+ The number of channels in the output. Specify this parameter if the output channel number differs from the
253
+ input.
254
+ num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
255
+ dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
256
+ norm_num_groups (int, optional, defaults to 32):
257
+ Number of groups for group normalization within Transformer blocks.
258
+ cross_attention_dim (int, optional):
259
+ The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
260
+ attention_bias (bool, optional, defaults to True):
261
+ Configure if the Transformer blocks' attention should contain a bias parameter.
262
+ sample_size (int, defaults to 128):
263
+ The width of the latent images. This parameter is fixed during training.
264
+ patch_size (int, defaults to 2):
265
+ Size of the patches the model processes, relevant for architectures working on non-sequential data.
266
+ activation_fn (str, optional, defaults to "gelu-approximate"):
267
+ Activation function to use in feed-forward networks within Transformer blocks.
268
+ num_embeds_ada_norm (int, optional, defaults to 1000):
269
+ Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
270
+ inference.
271
+ upcast_attention (bool, optional, defaults to False):
272
+ If true, upcasts the attention mechanism dimensions for potentially improved performance.
273
+ norm_type (str, optional, defaults to "ada_norm_single"):
274
+ Specifies the type of normalization used; only 'ada_norm_single' is supported.
275
+ norm_elementwise_affine (bool, optional, defaults to False):
276
+ If true, enables element-wise affine parameters in the normalization layers.
277
+ norm_eps (float, optional, defaults to 1e-6):
278
+ A small constant added to the denominator in normalization layers to prevent division by zero.
279
+ interpolation_scale (int, optional): Scale factor used when interpolating the position embeddings.
280
+ use_additional_conditions (bool, optional): Whether additional conditions are used as inputs.
281
+ attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
282
+ caption_channels (int, optional, defaults to None):
283
+ Number of channels to use for projecting the caption embeddings.
284
+ use_linear_projection (bool, optional, defaults to False):
285
+ Deprecated argument. Will be removed in a future version.
286
+ num_vector_embeds (bool, optional, defaults to False):
287
+ Deprecated argument. Will be removed in a future version.
288
+ """
289
+
290
+ _supports_gradient_checkpointing = True
291
+ _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
292
+
293
+ @register_to_config
294
+ def __init__(
295
+ self,
296
+ num_attention_heads: int = 16,
297
+ attention_head_dim: int = 72,
298
+ in_channels: int = 4,
299
+ out_channels: Optional[int] = 8,
300
+ num_layers: int = 28,
301
+ dropout: float = 0.0,
302
+ norm_num_groups: int = 32,
303
+ cross_attention_dim: Optional[int] = 1152,
304
+ attention_bias: bool = True,
305
+ sample_size: int = 128,
306
+ patch_size: int = 2,
307
+ activation_fn: str = "gelu-approximate",
308
+ num_embeds_ada_norm: Optional[int] = 1000,
309
+ upcast_attention: bool = False,
310
+ norm_type: str = "ada_norm_single",
311
+ norm_elementwise_affine: bool = False,
312
+ norm_eps: float = 1e-6,
313
+ interpolation_scale: Optional[int] = None,
314
+ use_additional_conditions: Optional[bool] = None,
315
+ caption_channels: Optional[int] = None,
316
+ caption_num_tokens: int = 1,
317
+ attention_type: Optional[str] = "default",
318
+ ):
319
+ super().__init__()
320
+
321
+ # Validate inputs.
322
+ if norm_type != "ada_norm_single":
323
+ raise NotImplementedError(
324
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
325
+ )
326
+ elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
327
+ raise ValueError(
328
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
329
+ )
330
+
331
+ # Set some common variables used across the board.
332
+ self.attention_head_dim = attention_head_dim
333
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
334
+ self.out_channels = in_channels if out_channels is None else out_channels
335
+ if use_additional_conditions is None:
336
+ if sample_size == 128:
337
+ use_additional_conditions = True
338
+ else:
339
+ use_additional_conditions = False
340
+ self.use_additional_conditions = use_additional_conditions
341
+
342
+ self.gradient_checkpointing = False
343
+
344
+ # 2. Initialize the position embedding and transformer blocks.
345
+ self.height = self.config.sample_size
346
+ self.width = self.config.sample_size
347
+
348
+ interpolation_scale = (
349
+ self.config.interpolation_scale
350
+ if self.config.interpolation_scale is not None
351
+ else max(self.config.sample_size // 64, 1)
352
+ )
353
+ self.pos_embed = PatchEmbed(
354
+ height=self.config.sample_size,
355
+ width=self.config.sample_size,
356
+ patch_size=self.config.patch_size,
357
+ in_channels=self.config.in_channels,
358
+ embed_dim=self.inner_dim,
359
+ interpolation_scale=interpolation_scale,
360
+ )
361
+
362
+ self.transformer_blocks = nn.ModuleList(
363
+ [
364
+ BasicTransformerBlock(
365
+ self.inner_dim,
366
+ self.config.num_attention_heads,
367
+ self.config.attention_head_dim,
368
+ dropout=self.config.dropout,
369
+ cross_attention_dim=self.config.cross_attention_dim,
370
+ activation_fn=self.config.activation_fn,
371
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
372
+ attention_bias=self.config.attention_bias,
373
+ upcast_attention=self.config.upcast_attention,
374
+ norm_type=norm_type,
375
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
376
+ norm_eps=self.config.norm_eps,
377
+ attention_type=self.config.attention_type,
378
+ )
379
+ for _ in range(self.config.num_layers)
380
+ ]
381
+ )
382
+
383
+ # Initialize the positional embedding for the conditions for >1 UNI embeddings
384
+ if self.config.caption_num_tokens == 1:
385
+ self.y_pos_embed = None
386
+ else:
387
+ # 1:1 aspect ratio
388
+ self.uni_height = int(self.config.caption_num_tokens ** 0.5)
389
+ self.uni_width = int(self.config.caption_num_tokens ** 0.5)
390
+
391
+ self.y_pos_embed = UNIPosEmbed(
392
+ height=self.uni_height,
393
+ width=self.uni_width,
394
+ base_size=self.config.sample_size // self.config.patch_size,
395
+ embed_dim=self.config.caption_channels,
396
+ interpolation_scale=2, # Should this be fixed?
397
+ pos_embed_type="sincos", # This is fixed
398
+ )
399
+
400
+ # 3. Output blocks.
401
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
402
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
403
+ self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
404
+
405
+ self.adaln_single = AdaLayerNormSingle(
406
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
407
+ )
408
+ self.caption_projection = None
409
+ if self.config.caption_channels is not None:
410
+ self.caption_projection = PixcellUNIProjection(
411
+ in_features=self.config.caption_channels, hidden_size=self.inner_dim, num_tokens=self.config.caption_num_tokens,
412
+ )
413
+
414
+ def _set_gradient_checkpointing(self, module, value=False):
415
+ if hasattr(module, "gradient_checkpointing"):
416
+ module.gradient_checkpointing = value
417
+
418
+ @property
419
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
420
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
421
+ r"""
422
+ Returns:
423
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
424
+ indexed by their weight names.
425
+ """
426
+ # set recursively
427
+ processors = {}
428
+
429
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
430
+ if hasattr(module, "get_processor"):
431
+ processors[f"{name}.processor"] = module.get_processor()
432
+
433
+ for sub_name, child in module.named_children():
434
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
435
+
436
+ return processors
437
+
438
+ for name, module in self.named_children():
439
+ fn_recursive_add_processors(name, module, processors)
440
+
441
+ return processors
442
+
443
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
444
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
445
+ r"""
446
+ Sets the attention processor to use to compute attention.
447
+
448
+ Parameters:
449
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
450
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
451
+ for **all** `Attention` layers.
452
+
453
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
454
+ processor. This is strongly recommended when setting trainable attention processors.
455
+
456
+ """
457
+ count = len(self.attn_processors.keys())
458
+
459
+ if isinstance(processor, dict) and len(processor) != count:
460
+ raise ValueError(
461
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
462
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
463
+ )
464
+
465
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
466
+ if hasattr(module, "set_processor"):
467
+ if not isinstance(processor, dict):
468
+ module.set_processor(processor)
469
+ else:
470
+ module.set_processor(processor.pop(f"{name}.processor"))
471
+
472
+ for sub_name, child in module.named_children():
473
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
474
+
475
+ for name, module in self.named_children():
476
+ fn_recursive_attn_processor(name, module, processor)
477
+
478
+ def set_default_attn_processor(self):
479
+ """
480
+ Disables custom attention processors and sets the default attention implementation.
481
+
482
+ Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in the default model.
483
+ """
484
+ self.set_attn_processor(AttnProcessor())
485
+
486
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
487
+ def fuse_qkv_projections(self):
488
+ """
489
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
490
+ are fused. For cross-attention modules, key and value projection matrices are fused.
491
+
492
+ <Tip warning={true}>
493
+
494
+ This API is 🧪 experimental.
495
+
496
+ </Tip>
497
+ """
498
+ self.original_attn_processors = None
499
+
500
+ for _, attn_processor in self.attn_processors.items():
501
+ if "Added" in str(attn_processor.__class__.__name__):
502
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
503
+
504
+ self.original_attn_processors = self.attn_processors
505
+
506
+ for module in self.modules():
507
+ if isinstance(module, Attention):
508
+ module.fuse_projections(fuse=True)
509
+
510
+ self.set_attn_processor(FusedAttnProcessor2_0())
511
+
512
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
513
+ def unfuse_qkv_projections(self):
514
+ """Disables the fused QKV projection if enabled.
515
+
516
+ <Tip warning={true}>
517
+
518
+ This API is 🧪 experimental.
519
+
520
+ </Tip>
521
+
522
+ """
523
+ if self.original_attn_processors is not None:
524
+ self.set_attn_processor(self.original_attn_processors)
525
+
526
+ def forward(
527
+ self,
528
+ hidden_states: torch.Tensor,
529
+ encoder_hidden_states: Optional[torch.Tensor] = None,
530
+ timestep: Optional[torch.LongTensor] = None,
531
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
532
+ cross_attention_kwargs: Dict[str, Any] = None,
533
+ attention_mask: Optional[torch.Tensor] = None,
534
+ encoder_attention_mask: Optional[torch.Tensor] = None,
535
+ return_dict: bool = True,
536
+ ):
537
+ """
538
+ The [`PixCellTransformer2DModel`] forward method.
539
+
540
+ Args:
541
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
542
+ Input `hidden_states`.
543
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
544
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
545
+ self-attention.
546
+ timestep (`torch.LongTensor`, *optional*):
547
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
548
+ added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
549
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
550
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
551
+ `self.processor` in
552
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
553
+ attention_mask ( `torch.Tensor`, *optional*):
554
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
555
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
556
+ negative values to the attention scores corresponding to "discard" tokens.
557
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
558
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
559
+
560
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
561
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
562
+
563
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
564
+ above. This bias will be added to the cross-attention scores.
565
+ return_dict (`bool`, *optional*, defaults to `True`):
566
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
567
+ tuple.
568
+
569
+ Returns:
570
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
571
+ `tuple` where the first element is the sample tensor.
572
+ """
573
+ if self.use_additional_conditions and added_cond_kwargs is None:
574
+ raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
575
+
576
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
577
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
578
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
579
+ # expects mask of shape:
580
+ # [batch, key_tokens]
581
+ # adds singleton query_tokens dimension:
582
+ # [batch, 1, key_tokens]
583
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
584
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
585
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
586
+ if attention_mask is not None and attention_mask.ndim == 2:
587
+ # assume that mask is expressed as:
588
+ # (1 = keep, 0 = discard)
589
+ # convert mask into a bias that can be added to attention scores:
590
+ # (keep = +0, discard = -10000.0)
591
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
592
+ attention_mask = attention_mask.unsqueeze(1)
593
+
594
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
595
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
596
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
597
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
598
+
599
+ # 1. Input
600
+ batch_size = hidden_states.shape[0]
601
+ height, width = (
602
+ hidden_states.shape[-2] // self.config.patch_size,
603
+ hidden_states.shape[-1] // self.config.patch_size,
604
+ )
605
+ hidden_states = self.pos_embed(hidden_states)
606
+
607
+ timestep, embedded_timestep = self.adaln_single(
608
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
609
+ )
610
+
611
+ if self.caption_projection is not None:
612
+ # Add positional embeddings to conditions if >1 UNI are given
613
+ if self.y_pos_embed is not None:
614
+ encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
615
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
616
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
617
+
618
+ # 2. Blocks
619
+ for block in self.transformer_blocks:
620
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
621
+
622
+ def create_custom_forward(module, return_dict=None):
623
+ def custom_forward(*inputs):
624
+ if return_dict is not None:
625
+ return module(*inputs, return_dict=return_dict)
626
+ else:
627
+ return module(*inputs)
628
+
629
+ return custom_forward
630
+
631
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
632
+ hidden_states = torch.utils.checkpoint.checkpoint(
633
+ create_custom_forward(block),
634
+ hidden_states,
635
+ attention_mask,
636
+ encoder_hidden_states,
637
+ encoder_attention_mask,
638
+ timestep,
639
+ cross_attention_kwargs,
640
+ None,
641
+ **ckpt_kwargs,
642
+ )
643
+ else:
644
+ hidden_states = block(
645
+ hidden_states,
646
+ attention_mask=attention_mask,
647
+ encoder_hidden_states=encoder_hidden_states,
648
+ encoder_attention_mask=encoder_attention_mask,
649
+ timestep=timestep,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ class_labels=None,
652
+ )
653
+
654
+ # 3. Output
655
+ shift, scale = (
656
+ self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
657
+ ).chunk(2, dim=1)
658
+ hidden_states = self.norm_out(hidden_states)
659
+ # Modulation
660
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
661
+ hidden_states = self.proj_out(hidden_states)
662
+ hidden_states = hidden_states.squeeze(1)
663
+
664
+ # unpatchify
665
+ hidden_states = hidden_states.reshape(
666
+ shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
667
+ )
668
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
669
+ output = hidden_states.reshape(
670
+ shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
671
+ )
672
+
673
+ if not return_dict:
674
+ return (output,)
675
+
676
+ return Transformer2DModelOutput(sample=output)
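End-to-end, the forward pass patchifies the latent, conditions every block on the projected UNI embedding via cross-attention and on the timestep via `AdaLayerNormSingle`, then unpatchifies back to `out_channels` maps. A hedged single-step sketch with random tensors, matching the transformer config earlier in this commit (the local path is a placeholder):

import torch
from pixcell_transformer_2d import PixCellTransformer2DModel

model = PixCellTransformer2DModel.from_pretrained("./PixCell", subfolder="transformer")

latents = torch.randn(1, 16, 32, 32)   # (batch, in_channels, sample_size, sample_size)
uni = torch.randn(1, 1, 1536)          # a single UNI embedding (caption_num_tokens=1)
timestep = torch.tensor([999])

with torch.no_grad():
    out = model(
        latents,
        encoder_hidden_states=uni,
        timestep=timestep,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    ).sample
print(out.shape)                       # torch.Size([1, 32, 32, 32]); 32 = out_channels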
vae/config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.32.2",
4
+ "_name_or_path": "stabilityai/stable-diffusion-3.5-large",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 16,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 1024,
28
+ "scaling_factor": 1.5305,
29
+ "shift_factor": 0.0609,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": false,
37
+ "use_quant_conv": false
38
+ }
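This is the 16-channel Stable Diffusion 3.5 autoencoder reused as-is; `scaling_factor` and `shift_factor` define how its latents are normalized for the diffusion model. A sketch of the usual diffusers (de)normalization convention for VAEs with a shift factor (placeholder path, illustrative shapes):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("./PixCell", subfolder="vae")  # placeholder local path

image = torch.rand(1, 3, 256, 256) * 2 - 1                         # RGB in [-1, 1]
with torch.no_grad():
    z = vae.encode(image).latent_dist.sample()                     # (1, 16, 32, 32)
    latents = (z - vae.config.shift_factor) * vae.config.scaling_factor
    # ... diffusion runs in this normalized latent space ...
    decoded = vae.decode(latents / vae.config.scaling_factor + vae.config.shift_factor).sample
print(decoded.shape)                                               # torch.Size([1, 3, 256, 256])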