Diffusers
Safetensors
PixCellPipeline
AlexGraikos committed
Commit 3f59797 · verified · 1 Parent(s): 7147c53

Delete pipeline.py

Files changed (1)
  1. pipeline.py +0 -1122
pipeline.py DELETED
@@ -1,1122 +0,0 @@
1
- # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- import urllib.parse as ul
17
- from typing import Callable, List, Optional, Tuple, Union
18
-
19
- import torch
20
- from torch import nn
21
-
22
- from diffusers.image_processor import PixArtImageProcessor
23
- from diffusers.models import AutoencoderKL
24
- from diffusers.schedulers import DPMSolverMultistepScheduler
25
- from diffusers.utils import (
26
- BACKENDS_MAPPING,
27
- deprecate,
28
- logging,
29
- replace_example_docstring,
30
- )
31
- from diffusers.utils.torch_utils import randn_tensor
32
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
33
-
34
- from pixcell_transformer_2d import PixCellTransformer2DModel
35
-
36
-
37
- from typing import Any, Dict, Optional, Union
38
-
39
-
40
- from diffusers.configuration_utils import ConfigMixin, register_to_config
41
- from diffusers.utils import is_torch_version, logging
42
- from diffusers.models.attention import BasicTransformerBlock
43
- from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
44
- from diffusers.models.embeddings import PatchEmbed
45
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
46
- from diffusers.models.modeling_utils import ModelMixin
47
- from diffusers.models.normalization import AdaLayerNormSingle
48
-
49
- from typing import List, Optional, Tuple, Union
50
-
51
- import numpy as np
52
- import torch
53
- import torch.nn.functional as F
54
- from torch import nn
55
-
56
- from diffusers.models.activations import deprecate, FP32SiLU
57
-
58
-
59
- def pixcell_get_2d_sincos_pos_embed(
60
- embed_dim,
61
- grid_size,
62
- cls_token=False,
63
- extra_tokens=0,
64
- interpolation_scale=1.0,
65
- base_size=16,
66
- device: Optional[torch.device] = None,
67
- phase=0,
68
- output_type: str = "np",
69
- ):
70
- """
71
- Creates 2D sinusoidal positional embeddings.
72
-
73
- Args:
74
- embed_dim (`int`):
75
- The embedding dimension.
76
- grid_size (`int`):
77
- The size of the grid height and width.
78
- cls_token (`bool`, defaults to `False`):
79
- Whether or not to add a classification token.
80
- extra_tokens (`int`, defaults to `0`):
81
- The number of extra tokens to add.
82
- interpolation_scale (`float`, defaults to `1.0`):
83
- The scale of the interpolation.
84
-
85
- Returns:
86
- pos_embed (`torch.Tensor`):
87
- Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
88
- embed_dim]` if using cls_token
89
- """
90
- if output_type == "np":
91
- deprecation_message = (
92
- "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
93
- " `from_numpy` is no longer required."
94
- " Pass `output_type='pt' to use the new version now."
95
- )
96
- deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
97
- raise ValueError("Not supported")
98
- if isinstance(grid_size, int):
99
- grid_size = (grid_size, grid_size)
100
-
101
- grid_h = (
102
- torch.arange(grid_size[0], device=device, dtype=torch.float32)
103
- / (grid_size[0] / base_size)
104
- / interpolation_scale
105
- )
106
- grid_w = (
107
- torch.arange(grid_size[1], device=device, dtype=torch.float32)
108
- / (grid_size[1] / base_size)
109
- / interpolation_scale
110
- )
111
- grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
112
- grid = torch.stack(grid, dim=0)
113
-
114
- grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
115
- pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
116
- if cls_token and extra_tokens > 0:
117
- pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
118
- return pos_embed
119
-
120
-
121
- def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
122
- r"""
123
- This function generates 2D sinusoidal positional embeddings from a grid.
124
-
125
- Args:
126
- embed_dim (`int`): The embedding dimension.
127
- grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
128
-
129
- Returns:
130
- `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
131
- """
132
- if output_type == "np":
133
- deprecation_message = (
134
- "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
135
- " `from_numpy` is no longer required."
136
- " Pass `output_type='pt' to use the new version now."
137
- )
138
- deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
139
- raise ValueError("Not supported")
140
- if embed_dim % 2 != 0:
141
- raise ValueError("embed_dim must be divisible by 2")
142
-
143
- # use half of dimensions to encode grid_h
144
- emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type) # (H*W, D/2)
145
- emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type) # (H*W, D/2)
146
-
147
- emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
148
- return emb
149
-
150
-
151
- def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
152
- """
153
- This function generates 1D positional embeddings from a grid.
154
-
155
- Args:
156
- embed_dim (`int`): The embedding dimension `D`
157
- pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
158
-
159
- Returns:
160
- `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
161
- """
162
- if output_type == "np":
163
- deprecation_message = (
164
- "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
165
- " `from_numpy` is no longer required."
166
- " Pass `output_type='pt' to use the new version now."
167
- )
168
- deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
169
- raise ValueError("Not supported")
170
- if embed_dim % 2 != 0:
171
- raise ValueError("embed_dim must be divisible by 2")
172
-
173
- omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
174
- omega /= embed_dim / 2.0
175
- omega = 1.0 / 10000**omega # (D/2,)
176
-
177
- pos = pos.reshape(-1) + phase # (M,)
178
- out = torch.outer(pos, omega) # (M, D/2), outer product
179
-
180
- emb_sin = torch.sin(out) # (M, D/2)
181
- emb_cos = torch.cos(out) # (M, D/2)
182
-
183
- emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
184
- return emb
185
-
186
-
187
- class PixcellUNIProjection(nn.Module):
188
- """
189
- Projects UNI embeddings. Also handles dropout for classifier-free guidance.
190
-
191
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
192
- """
193
-
194
- def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
195
- super().__init__()
196
- if out_features is None:
197
- out_features = hidden_size
198
- self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
199
- if act_fn == "gelu_tanh":
200
- self.act_1 = nn.GELU(approximate="tanh")
201
- elif act_fn == "silu":
202
- self.act_1 = nn.SiLU()
203
- elif act_fn == "silu_fp32":
204
- self.act_1 = FP32SiLU()
205
- else:
206
- raise ValueError(f"Unknown activation function: {act_fn}")
207
- self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
208
-
209
- self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features ** 0.5))
210
-
211
- def forward(self, caption):
212
- hidden_states = self.linear_1(caption)
213
- hidden_states = self.act_1(hidden_states)
214
- hidden_states = self.linear_2(hidden_states)
215
- return hidden_states
216
-
217
- class UNIPosEmbed(nn.Module):
218
- """
219
- Adds positional embeddings to the UNI conditions.
220
-
221
- Args:
222
- height (`int`, defaults to `224`): The height of the image.
223
- width (`int`, defaults to `224`): The width of the image.
224
- patch_size (`int`, defaults to `16`): The size of the patches.
225
- in_channels (`int`, defaults to `3`): The number of input channels.
226
- embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
227
- layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
228
- flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
229
- bias (`bool`, defaults to `True`): Whether or not to use bias.
230
- interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
231
- pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
232
- pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
233
- """
234
-
235
- def __init__(
236
- self,
237
- height=1,
238
- width=1,
239
- base_size=16,
240
- embed_dim=768,
241
- interpolation_scale=1,
242
- pos_embed_type="sincos",
243
- ):
244
- super().__init__()
245
-
246
- num_embeds = height*width
247
- grid_size = int(num_embeds ** 0.5)
248
-
249
- if pos_embed_type == "sincos":
250
- y_pos_embed = pixcell_get_2d_sincos_pos_embed(
251
- embed_dim,
252
- grid_size,
253
- base_size=base_size,
254
- interpolation_scale=interpolation_scale,
255
- output_type="pt",
256
- phase = base_size // num_embeds
257
- )
258
- self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
259
- else:
260
- raise ValueError("`pos_embed_type` not supported")
261
-
262
- def forward(self, uni_embeds):
263
- return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
264
-
265
-
266
-
267
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
268
-
269
-
270
- class PixCellTransformer2DModel(ModelMixin, ConfigMixin):
271
- r"""
272
- A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
273
- https://arxiv.org/abs/2403.04692). Modified for the pathology domain.
274
-
275
- Parameters:
276
- num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
277
- attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
278
- in_channels (int, defaults to 4): The number of channels in the input.
279
- out_channels (int, optional):
280
- The number of channels in the output. Specify this parameter if the output channel number differs from the
281
- input.
282
- num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
283
- dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
284
- norm_num_groups (int, optional, defaults to 32):
285
- Number of groups for group normalization within Transformer blocks.
286
- cross_attention_dim (int, optional):
287
- The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
288
- attention_bias (bool, optional, defaults to True):
289
- Configure if the Transformer blocks' attention should contain a bias parameter.
290
- sample_size (int, defaults to 128):
291
- The width of the latent images. This parameter is fixed during training.
292
- patch_size (int, defaults to 2):
293
- Size of the patches the model processes, relevant for architectures working on non-sequential data.
294
- activation_fn (str, optional, defaults to "gelu-approximate"):
295
- Activation function to use in feed-forward networks within Transformer blocks.
296
- num_embeds_ada_norm (int, optional, defaults to 1000):
297
- Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
298
- inference.
299
- upcast_attention (bool, optional, defaults to False):
300
- If true, upcasts the attention mechanism dimensions for potentially improved performance.
301
- norm_type (str, optional, defaults to "ada_norm_zero"):
302
- Specifies the type of normalization used, can be 'ada_norm_zero'.
303
- norm_elementwise_affine (bool, optional, defaults to False):
304
- If true, enables element-wise affine parameters in the normalization layers.
305
- norm_eps (float, optional, defaults to 1e-6):
306
- A small constant added to the denominator in normalization layers to prevent division by zero.
307
- interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
308
- use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
309
- attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
310
- caption_channels (int, optional, defaults to None):
311
- Number of channels to use for projecting the caption embeddings.
312
- use_linear_projection (bool, optional, defaults to False):
313
- Deprecated argument. Will be removed in a future version.
314
- num_vector_embeds (bool, optional, defaults to False):
315
- Deprecated argument. Will be removed in a future version.
316
- """
317
-
318
- _supports_gradient_checkpointing = True
319
- _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
320
-
321
- @register_to_config
322
- def __init__(
323
- self,
324
- num_attention_heads: int = 16,
325
- attention_head_dim: int = 72,
326
- in_channels: int = 4,
327
- out_channels: Optional[int] = 8,
328
- num_layers: int = 28,
329
- dropout: float = 0.0,
330
- norm_num_groups: int = 32,
331
- cross_attention_dim: Optional[int] = 1152,
332
- attention_bias: bool = True,
333
- sample_size: int = 128,
334
- patch_size: int = 2,
335
- activation_fn: str = "gelu-approximate",
336
- num_embeds_ada_norm: Optional[int] = 1000,
337
- upcast_attention: bool = False,
338
- norm_type: str = "ada_norm_single",
339
- norm_elementwise_affine: bool = False,
340
- norm_eps: float = 1e-6,
341
- interpolation_scale: Optional[int] = None,
342
- use_additional_conditions: Optional[bool] = None,
343
- caption_channels: Optional[int] = None,
344
- caption_num_tokens: int = 1,
345
- attention_type: Optional[str] = "default",
346
- ):
347
- super().__init__()
348
-
349
- # Validate inputs.
350
- if norm_type != "ada_norm_single":
351
- raise NotImplementedError(
352
- f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
353
- )
354
- elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
355
- raise ValueError(
356
- f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
357
- )
358
-
359
- # Set some common variables used across the board.
360
- self.attention_head_dim = attention_head_dim
361
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
362
- self.out_channels = in_channels if out_channels is None else out_channels
363
- if use_additional_conditions is None:
364
- if sample_size == 128:
365
- use_additional_conditions = True
366
- else:
367
- use_additional_conditions = False
368
- self.use_additional_conditions = use_additional_conditions
369
-
370
- self.gradient_checkpointing = False
371
-
372
- # 2. Initialize the position embedding and transformer blocks.
373
- self.height = self.config.sample_size
374
- self.width = self.config.sample_size
375
-
376
- interpolation_scale = (
377
- self.config.interpolation_scale
378
- if self.config.interpolation_scale is not None
379
- else max(self.config.sample_size // 64, 1)
380
- )
381
- self.pos_embed = PatchEmbed(
382
- height=self.config.sample_size,
383
- width=self.config.sample_size,
384
- patch_size=self.config.patch_size,
385
- in_channels=self.config.in_channels,
386
- embed_dim=self.inner_dim,
387
- interpolation_scale=interpolation_scale,
388
- )
389
-
390
- self.transformer_blocks = nn.ModuleList(
391
- [
392
- BasicTransformerBlock(
393
- self.inner_dim,
394
- self.config.num_attention_heads,
395
- self.config.attention_head_dim,
396
- dropout=self.config.dropout,
397
- cross_attention_dim=self.config.cross_attention_dim,
398
- activation_fn=self.config.activation_fn,
399
- num_embeds_ada_norm=self.config.num_embeds_ada_norm,
400
- attention_bias=self.config.attention_bias,
401
- upcast_attention=self.config.upcast_attention,
402
- norm_type=norm_type,
403
- norm_elementwise_affine=self.config.norm_elementwise_affine,
404
- norm_eps=self.config.norm_eps,
405
- attention_type=self.config.attention_type,
406
- )
407
- for _ in range(self.config.num_layers)
408
- ]
409
- )
410
-
411
- # Initialize the positional embedding for the conditions for >1 UNI embeddings
412
- if self.config.caption_num_tokens == 1:
413
- self.y_pos_embed = None
414
- else:
415
- # 1:1 aspect ratio
416
- self.uni_height = int(self.config.caption_num_tokens ** 0.5)
417
- self.uni_width = int(self.config.caption_num_tokens ** 0.5)
418
-
419
- self.y_pos_embed = UNIPosEmbed(
420
- height=self.uni_height,
421
- width=self.uni_width,
422
- base_size=self.config.sample_size // self.config.patch_size,
423
- embed_dim=self.config.caption_channels,
424
- interpolation_scale=2, # Should this be fixed?
425
- pos_embed_type="sincos", # This is fixed
426
- )
427
-
428
- # 3. Output blocks.
429
- self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
430
- self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
431
- self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
432
-
433
- self.adaln_single = AdaLayerNormSingle(
434
- self.inner_dim, use_additional_conditions=self.use_additional_conditions
435
- )
436
- self.caption_projection = None
437
- if self.config.caption_channels is not None:
438
- self.caption_projection = PixcellUNIProjection(
439
- in_features=self.config.caption_channels, hidden_size=self.inner_dim, num_tokens=self.config.caption_num_tokens,
440
- )
441
-
442
- def _set_gradient_checkpointing(self, module, value=False):
443
- if hasattr(module, "gradient_checkpointing"):
444
- module.gradient_checkpointing = value
445
-
446
- @property
447
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
448
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
449
- r"""
450
- Returns:
451
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
452
- indexed by its weight name.
453
- """
454
- # set recursively
455
- processors = {}
456
-
457
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
458
- if hasattr(module, "get_processor"):
459
- processors[f"{name}.processor"] = module.get_processor()
460
-
461
- for sub_name, child in module.named_children():
462
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
463
-
464
- return processors
465
-
466
- for name, module in self.named_children():
467
- fn_recursive_add_processors(name, module, processors)
468
-
469
- return processors
470
-
471
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
472
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
473
- r"""
474
- Sets the attention processor to use to compute attention.
475
-
476
- Parameters:
477
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
478
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
479
- for **all** `Attention` layers.
480
-
481
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
482
- processor. This is strongly recommended when setting trainable attention processors.
483
-
484
- """
485
- count = len(self.attn_processors.keys())
486
-
487
- if isinstance(processor, dict) and len(processor) != count:
488
- raise ValueError(
489
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
490
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
491
- )
492
-
493
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
494
- if hasattr(module, "set_processor"):
495
- if not isinstance(processor, dict):
496
- module.set_processor(processor)
497
- else:
498
- module.set_processor(processor.pop(f"{name}.processor"))
499
-
500
- for sub_name, child in module.named_children():
501
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
502
-
503
- for name, module in self.named_children():
504
- fn_recursive_attn_processor(name, module, processor)
505
-
506
- def set_default_attn_processor(self):
507
- """
508
- Disables custom attention processors and sets the default attention implementation.
509
-
510
- Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
511
- """
512
- self.set_attn_processor(AttnProcessor())
513
-
514
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
515
- def fuse_qkv_projections(self):
516
- """
517
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
518
- are fused. For cross-attention modules, key and value projection matrices are fused.
519
-
520
- <Tip warning={true}>
521
-
522
- This API is 🧪 experimental.
523
-
524
- </Tip>
525
- """
526
- self.original_attn_processors = None
527
-
528
- for _, attn_processor in self.attn_processors.items():
529
- if "Added" in str(attn_processor.__class__.__name__):
530
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
531
-
532
- self.original_attn_processors = self.attn_processors
533
-
534
- for module in self.modules():
535
- if isinstance(module, Attention):
536
- module.fuse_projections(fuse=True)
537
-
538
- self.set_attn_processor(FusedAttnProcessor2_0())
539
-
540
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
541
- def unfuse_qkv_projections(self):
542
- """Disables the fused QKV projection if enabled.
543
-
544
- <Tip warning={true}>
545
-
546
- This API is 🧪 experimental.
547
-
548
- </Tip>
549
-
550
- """
551
- if self.original_attn_processors is not None:
552
- self.set_attn_processor(self.original_attn_processors)
553
-
554
- def forward(
555
- self,
556
- hidden_states: torch.Tensor,
557
- encoder_hidden_states: Optional[torch.Tensor] = None,
558
- timestep: Optional[torch.LongTensor] = None,
559
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
560
- cross_attention_kwargs: Dict[str, Any] = None,
561
- attention_mask: Optional[torch.Tensor] = None,
562
- encoder_attention_mask: Optional[torch.Tensor] = None,
563
- return_dict: bool = True,
564
- ):
565
- """
566
- The [`PixCellTransformer2DModel`] forward method.
567
-
568
- Args:
569
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
570
- Input `hidden_states`.
571
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
572
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
573
- self-attention.
574
- timestep (`torch.LongTensor`, *optional*):
575
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
576
- added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
577
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
578
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
579
- `self.processor` in
580
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
581
- attention_mask ( `torch.Tensor`, *optional*):
582
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
583
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
584
- negative values to the attention scores corresponding to "discard" tokens.
585
- encoder_attention_mask ( `torch.Tensor`, *optional*):
586
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
587
-
588
- * Mask `(batch, sequence_length)` True = keep, False = discard.
589
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
590
-
591
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
592
- above. This bias will be added to the cross-attention scores.
593
- return_dict (`bool`, *optional*, defaults to `True`):
594
- Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
595
- tuple.
596
-
597
- Returns:
598
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
599
- `tuple` where the first element is the sample tensor.
600
- """
601
- if self.use_additional_conditions and added_cond_kwargs is None:
602
- raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
603
-
604
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
605
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
606
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
607
- # expects mask of shape:
608
- # [batch, key_tokens]
609
- # adds singleton query_tokens dimension:
610
- # [batch, 1, key_tokens]
611
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
612
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
613
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
614
- if attention_mask is not None and attention_mask.ndim == 2:
615
- # assume that mask is expressed as:
616
- # (1 = keep, 0 = discard)
617
- # convert mask into a bias that can be added to attention scores:
618
- # (keep = +0, discard = -10000.0)
619
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
620
- attention_mask = attention_mask.unsqueeze(1)
621
-
622
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
623
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
624
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
625
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
626
-
627
- # 1. Input
628
- batch_size = hidden_states.shape[0]
629
- height, width = (
630
- hidden_states.shape[-2] // self.config.patch_size,
631
- hidden_states.shape[-1] // self.config.patch_size,
632
- )
633
- hidden_states = self.pos_embed(hidden_states)
634
-
635
- timestep, embedded_timestep = self.adaln_single(
636
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
637
- )
638
-
639
- if self.caption_projection is not None:
640
- # Add positional embeddings to conditions if >1 UNI are given
641
- if self.y_pos_embed is not None:
642
- encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
643
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
644
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
645
-
646
- # 2. Blocks
647
- for block in self.transformer_blocks:
648
- if torch.is_grad_enabled() and self.gradient_checkpointing:
649
-
650
- def create_custom_forward(module, return_dict=None):
651
- def custom_forward(*inputs):
652
- if return_dict is not None:
653
- return module(*inputs, return_dict=return_dict)
654
- else:
655
- return module(*inputs)
656
-
657
- return custom_forward
658
-
659
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
660
- hidden_states = torch.utils.checkpoint.checkpoint(
661
- create_custom_forward(block),
662
- hidden_states,
663
- attention_mask,
664
- encoder_hidden_states,
665
- encoder_attention_mask,
666
- timestep,
667
- cross_attention_kwargs,
668
- None,
669
- **ckpt_kwargs,
670
- )
671
- else:
672
- hidden_states = block(
673
- hidden_states,
674
- attention_mask=attention_mask,
675
- encoder_hidden_states=encoder_hidden_states,
676
- encoder_attention_mask=encoder_attention_mask,
677
- timestep=timestep,
678
- cross_attention_kwargs=cross_attention_kwargs,
679
- class_labels=None,
680
- )
681
-
682
- # 3. Output
683
- shift, scale = (
684
- self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
685
- ).chunk(2, dim=1)
686
- hidden_states = self.norm_out(hidden_states)
687
- # Modulation
688
- hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
689
- hidden_states = self.proj_out(hidden_states)
690
- hidden_states = hidden_states.squeeze(1)
691
-
692
- # unpatchify
693
- hidden_states = hidden_states.reshape(
694
- shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
695
- )
696
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
697
- output = hidden_states.reshape(
698
- shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
699
- )
700
-
701
- if not return_dict:
702
- return (output,)
703
-
704
- return Transformer2DModelOutput(sample=output)
705
-
706
-
707
-
708
-
709
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
710
-
711
-
712
- EXAMPLE_DOC_STRING = """
713
- Examples:
714
- ```py
715
- >>> import torch
716
- >>> from pipeline import PixCellPipeline
717
- >>> # Load a PixCell checkpoint (the path below is a placeholder).
718
- >>> pipe = PixCellPipeline.from_pretrained(
719
- ... "path/to/pixcell-checkpoint", torch_dtype=torch.float16
720
- ... )
721
- >>> # Enable memory optimizations.
722
- >>> # pipe.enable_model_cpu_offload()
723
-
724
- >>> uni_embeds = ...  # UNI embedding of shape (batch_size, num_tokens, channels)
725
- >>> negative_uni_embeds = pipe.get_unconditional_embedding(uni_embeds.shape[0])
726
- >>> image = pipe(uni_embeds=uni_embeds, negative_uni_embeds=negative_uni_embeds).images[0]
727
- ```
728
- """
729
-
730
-
731
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
732
- def retrieve_timesteps(
733
- scheduler,
734
- num_inference_steps: Optional[int] = None,
735
- device: Optional[Union[str, torch.device]] = None,
736
- timesteps: Optional[List[int]] = None,
737
- sigmas: Optional[List[float]] = None,
738
- **kwargs,
739
- ):
740
- r"""
741
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
742
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
743
-
744
- Args:
745
- scheduler (`SchedulerMixin`):
746
- The scheduler to get timesteps from.
747
- num_inference_steps (`int`):
748
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
749
- must be `None`.
750
- device (`str` or `torch.device`, *optional*):
751
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
752
- timesteps (`List[int]`, *optional*):
753
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
754
- `num_inference_steps` and `sigmas` must be `None`.
755
- sigmas (`List[float]`, *optional*):
756
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
757
- `num_inference_steps` and `timesteps` must be `None`.
758
-
759
- Returns:
760
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
761
- second element is the number of inference steps.
762
- """
763
- if timesteps is not None and sigmas is not None:
764
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
765
- if timesteps is not None:
766
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
767
- if not accepts_timesteps:
768
- raise ValueError(
769
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
770
- f" timestep schedules. Please check whether you are using the correct scheduler."
771
- )
772
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
773
- timesteps = scheduler.timesteps
774
- num_inference_steps = len(timesteps)
775
- elif sigmas is not None:
776
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
777
- if not accept_sigmas:
778
- raise ValueError(
779
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
780
- f" sigmas schedules. Please check whether you are using the correct scheduler."
781
- )
782
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
783
- timesteps = scheduler.timesteps
784
- num_inference_steps = len(timesteps)
785
- else:
786
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
787
- timesteps = scheduler.timesteps
788
- return timesteps, num_inference_steps
789
-
790
-
791
- class PixCellPipeline(DiffusionPipeline):
792
- r"""
793
- Pipeline for SSL-to-image generation using PixCell.
794
- """
795
-
796
- model_cpu_offload_seq = "transformer->vae"
797
-
798
- def __init__(
799
- self,
800
- vae: AutoencoderKL,
801
- transformer: PixCellTransformer2DModel,
802
- scheduler: DPMSolverMultistepScheduler,
803
- ):
804
- super().__init__()
805
-
806
- self.register_modules(
807
- vae=vae, transformer=transformer, scheduler=scheduler
808
- )
809
-
810
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
811
- self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
812
-
813
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
814
- def prepare_extra_step_kwargs(self, generator, eta):
815
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
816
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
817
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
818
- # and should be between [0, 1]
819
-
820
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
821
- extra_step_kwargs = {}
822
- if accepts_eta:
823
- extra_step_kwargs["eta"] = eta
824
-
825
- # check if the scheduler accepts generator
826
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
827
- if accepts_generator:
828
- extra_step_kwargs["generator"] = generator
829
- return extra_step_kwargs
830
-
831
- def get_unconditional_embedding(self, batch_size=1):
832
- # Unconditional embedding is learned
833
- uncond = self.transformer.caption_projection.uncond_embedding.clone().tile(batch_size,1,1)
834
- return uncond
835
-
836
- # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
837
- def check_inputs(
838
- self,
839
- height,
840
- width,
841
- callback_steps,
842
- uni_embeds=None,
843
- negative_uni_embeds=None,
844
- guidance_scale=None,
845
- ):
846
- if height % 8 != 0 or width % 8 != 0:
847
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
848
-
849
- if (callback_steps is None) or (
850
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
851
- ):
852
- raise ValueError(
853
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
854
- f" {type(callback_steps)}."
855
- )
856
-
857
- if uni_embeds is None:
858
- raise ValueError(
859
- "Provide a UNI embedding `uni_embeds`."
860
- )
861
- elif len(uni_embeds.shape) != 3:
862
- raise ValueError(
863
- "UNI embedding given is not in (B,N,D)."
864
- )
865
- elif uni_embeds.shape[1] != self.transformer.config.caption_num_tokens:
866
- raise ValueError(
867
- f"Number of UNI embeddings must match the ones used in training ({self.transformer.config.caption_num_tokens})."
868
- )
869
- elif uni_embeds.shape[2] != self.transformer.config.caption_channels:
870
- raise ValueError(
871
- "UNI embedding given has incorrect dimenions."
872
- )
873
-
874
- if guidance_scale > 1.0:
875
- if negative_uni_embeds is None:
876
- raise ValueError(
877
- "Provide a negative UNI embedding `negative_uni_embeds`."
878
- )
879
- elif len(negative_uni_embeds.shape) != 3:
880
- raise ValueError(
881
- "Negative UNI embedding given is not in (B,N,D)."
882
- )
883
- elif negative_uni_embeds.shape[1] != self.transformer.config.caption_num_tokens:
884
- raise ValueError(
885
- f"Number of negative UNI embeddings must match the ones used in training ({self.transformer.config.caption_num_tokens})."
886
- )
887
- elif negative_uni_embeds.shape[2] != self.transformer.config.caption_channels:
888
- raise ValueError(
889
- "Negative UNI embedding given has incorrect dimenions."
890
- )
891
-
892
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
893
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
894
- shape = (
895
- batch_size,
896
- num_channels_latents,
897
- int(height) // self.vae_scale_factor,
898
- int(width) // self.vae_scale_factor,
899
- )
900
- if isinstance(generator, list) and len(generator) != batch_size:
901
- raise ValueError(
902
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
903
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
904
- )
905
-
906
- if latents is None:
907
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
908
- else:
909
- latents = latents.to(device)
910
-
911
- # scale the initial noise by the standard deviation required by the scheduler
912
- latents = latents * self.scheduler.init_noise_sigma
913
- return latents
914
-
915
- @torch.no_grad()
916
- @replace_example_docstring(EXAMPLE_DOC_STRING)
917
- def __call__(
918
- self,
919
- num_inference_steps: int = 20,
920
- timesteps: List[int] = None,
921
- sigmas: List[float] = None,
922
- guidance_scale: float = 1.5,
923
- num_images_per_prompt: Optional[int] = 1,
924
- height: Optional[int] = None,
925
- width: Optional[int] = None,
926
- eta: float = 0.0,
927
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
928
- latents: Optional[torch.Tensor] = None,
929
- uni_embeds: Optional[torch.Tensor] = None,
930
- negative_uni_embeds: Optional[torch.Tensor] = None,
931
- output_type: Optional[str] = "pil",
932
- return_dict: bool = True,
933
- callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
934
- callback_steps: int = 1,
935
- **kwargs,
936
- ) -> Union[ImagePipelineOutput, Tuple]:
937
- """
938
- Function invoked when calling the pipeline for generation.
939
-
940
- Args:
941
- num_inference_steps (`int`, *optional*, defaults to 20):
942
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
943
- expense of slower inference.
944
- timesteps (`List[int]`, *optional*):
945
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
946
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
947
- passed will be used. Must be in descending order.
948
- sigmas (`List[float]`, *optional*):
949
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
950
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
951
- will be used.
952
- guidance_scale (`float`, *optional*, defaults to 1.5):
953
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
954
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
955
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
956
- 1`. A higher guidance scale encourages generating images that are closely linked to the conditioning `uni_embeds`,
957
- usually at the expense of lower image quality.
958
- num_images_per_prompt (`int`, *optional*, defaults to 1):
959
- The number of images to generate per prompt.
960
- height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
961
- The height in pixels of the generated image.
962
- width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
963
- The width in pixels of the generated image.
964
- eta (`float`, *optional*, defaults to 0.0):
965
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
966
- [`schedulers.DDIMScheduler`], will be ignored for others.
967
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
968
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
969
- to make generation deterministic.
970
- latents (`torch.Tensor`, *optional*):
971
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
972
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
973
- tensor will be generated by sampling using the supplied random `generator`.
974
- uni_embeds (`torch.Tensor`, *optional*):
975
- Pre-computed UNI embeddings of shape `(batch_size, num_tokens, channels)` used to condition generation.
976
- negative_uni_embeds (`torch.Tensor`, *optional*):
977
- Pre-computed negative UNI embeddings, required for classifier-free guidance when `guidance_scale > 1`.
978
- output_type (`str`, *optional*, defaults to `"pil"`):
979
- The output format of the generated image. Choose between
980
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
981
- return_dict (`bool`, *optional*, defaults to `True`):
982
- Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
983
- callback (`Callable`, *optional*):
984
- A function that will be called every `callback_steps` steps during inference. The function will be
985
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
986
- callback_steps (`int`, *optional*, defaults to 1):
987
- The frequency at which the `callback` function will be called. If not specified, the callback will be
988
- called at every step.
989
-
990
- Examples:
991
-
992
- Returns:
993
- [`~pipelines.ImagePipelineOutput`] or `tuple`:
994
- If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
995
- returned where the first element is a list with the generated images
996
- """
997
- # 1. Check inputs. Raise error if not correct
998
- height = height or self.transformer.config.sample_size * self.vae_scale_factor
999
- width = width or self.transformer.config.sample_size * self.vae_scale_factor
1000
-
1001
- self.check_inputs(
1002
- height,
1003
- width,
1004
- callback_steps,
1005
- uni_embeds,
1006
- negative_uni_embeds,
1007
- guidance_scale,
1008
- )
1009
-
1010
- # 2. Define call parameters
1011
- batch_size = uni_embeds.shape[0]
1012
-
1013
- device = self._execution_device
1014
-
1015
- # 3. Handle UNI conditioning
1016
-
1017
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1018
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1019
- # corresponds to doing no classifier free guidance.
1020
- do_classifier_free_guidance = guidance_scale > 1.0
1021
-
1022
- uni_embeds = uni_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1023
- if do_classifier_free_guidance:
1024
- negative_uni_embeds = negative_uni_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1025
- uni_embeds = torch.cat([negative_uni_embeds, uni_embeds], dim=0)
1026
-
1027
-
1028
- # 4. Prepare timesteps
1029
- timesteps, num_inference_steps = retrieve_timesteps(
1030
- self.scheduler, num_inference_steps, device, timesteps, sigmas
1031
- )
1032
-
1033
- # 5. Prepare latents.
1034
- latent_channels = self.transformer.config.in_channels
1035
- latents = self.prepare_latents(
1036
- batch_size * num_images_per_prompt,
1037
- latent_channels,
1038
- height,
1039
- width,
1040
- uni_embeds.dtype,
1041
- device,
1042
- generator,
1043
- latents,
1044
- )
1045
-
1046
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1047
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1048
-
1049
- added_cond_kwargs = {}
1050
-
1051
- # 7. Denoising loop
1052
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1053
-
1054
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1055
- for i, t in enumerate(timesteps):
1056
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1057
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1058
-
1059
- current_timestep = t
1060
- if not torch.is_tensor(current_timestep):
1061
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1062
- # This would be a good case for the `match` statement (Python 3.10+)
1063
- is_mps = latent_model_input.device.type == "mps"
1064
- if isinstance(current_timestep, float):
1065
- dtype = torch.float32 if is_mps else torch.float64
1066
- else:
1067
- dtype = torch.int32 if is_mps else torch.int64
1068
- current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
1069
- elif len(current_timestep.shape) == 0:
1070
- current_timestep = current_timestep[None].to(latent_model_input.device)
1071
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1072
- current_timestep = current_timestep.expand(latent_model_input.shape[0])
1073
-
1074
- # predict noise model_output
1075
- noise_pred = self.transformer(
1076
- latent_model_input,
1077
- encoder_hidden_states=uni_embeds,
1078
- timestep=current_timestep,
1079
- added_cond_kwargs=added_cond_kwargs,
1080
- return_dict=False,
1081
- )[0]
1082
-
1083
- # perform guidance
1084
- if do_classifier_free_guidance:
1085
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1086
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1087
-
1088
- # learned sigma
1089
- if self.transformer.config.out_channels // 2 == latent_channels:
1090
- noise_pred = noise_pred.chunk(2, dim=1)[0]
1091
- else:
1092
- noise_pred = noise_pred
1093
-
1094
- # compute previous image: x_t -> x_t-1
1095
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1096
-
1097
- # call the callback, if provided
1098
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1099
- progress_bar.update()
1100
- if callback is not None and i % callback_steps == 0:
1101
- step_idx = i // getattr(self.scheduler, "order", 1)
1102
- callback(step_idx, t, latents)
1103
-
1104
- if not output_type == "latent":
1105
- vae_scale = self.vae.config.scaling_factor
1106
- vae_shift = getattr(self.vae.config, "shift_factor", 0)
1107
-
1108
- image = self.vae.decode((latents / vae_scale) + vae_shift, return_dict=False)[0]
1109
-
1110
- else:
1111
- image = latents
1112
-
1113
- if not output_type == "latent":
1114
- image = self.image_processor.postprocess(image, output_type=output_type)
1115
-
1116
- # Offload all models
1117
- self.maybe_free_model_hooks()
1118
-
1119
- if not return_dict:
1120
- return (image,)
1121
-
1122
- return ImagePipelineOutput(images=image)
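
For reference, a minimal usage sketch of the deleted `PixCellPipeline`, based on the `__call__` signature above. The checkpoint path and the random UNI embedding are placeholders for illustration only; in practice the conditioning would come from a UNI feature extractor rather than `torch.randn`.

```py
# Minimal sketch only; "path/to/pixcell-checkpoint" is a placeholder and the random
# uni_embeds stands in for features produced by a UNI encoder.
import torch

from pipeline import PixCellPipeline  # the module deleted in this commit

pipe = PixCellPipeline.from_pretrained("path/to/pixcell-checkpoint", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Conditioning must be shaped (batch_size, caption_num_tokens, caption_channels).
num_tokens = pipe.transformer.config.caption_num_tokens
channels = pipe.transformer.config.caption_channels
uni_embeds = torch.randn(1, num_tokens, channels, dtype=torch.float16, device="cuda")

# The learned unconditional embedding enables classifier-free guidance (guidance_scale > 1).
negative_uni_embeds = pipe.get_unconditional_embedding(batch_size=uni_embeds.shape[0])

image = pipe(
    uni_embeds=uni_embeds,
    negative_uni_embeds=negative_uni_embeds,
    guidance_scale=1.5,
    num_inference_steps=20,
).images[0]
image.save("pixcell_sample.png")
```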