Create models.unet_3d_condition.py

#80
Files changed (1)
  1. unet/models.unet_3d_condition.py +500 -0
unet/models.unet_3d_condition.py ADDED
@@ -0,0 +1,500 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# Copyright 2023 The ModelScope Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
    transformer_g_c,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of the last layer of the model.
    """

    sample: torch.FloatTensor

class UNet3DConditionModel(ModelMixin, ConfigMixin):
    r"""
    UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
    ):
        super().__init__()

        self.sample_size = sample_size
        self.gradient_checkpointing = False
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
        )

        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=False,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = nn.SiLU()
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
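
        # With the default config this builds four resolution levels
        # (320/640/1280/1280 channels), cross-attention over 1024-dim encoder
        # hidden states, and a temporal transformer (`self.transformer_in`)
        # applied right after `conv_in` in `forward` (when more than one frame
        # is present) so frames can attend to each other before the 3D
        # down/mid/up blocks run.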

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
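
        # Example (illustrative): `self.set_attention_slice("auto")` halves each
        # sliceable head dimension, `"max"` uses a slice size of 1 everywhere, and
        # an integer applies that fixed slice size to every attention layer,
        # trading a little speed for lower peak memory.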

    def _set_gradient_checkpointing(self, value=False):
        self.gradient_checkpointing = value
        self.mid_block.gradient_checkpointing = value
        for module in self.down_blocks + self.up_blocks:
            if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
                module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Returns:
            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
            [`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

        # 2. pre-process
        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
        sample = self.conv_in(sample)

        if num_frames > 1:
            if self.gradient_checkpointing:
                sample = transformer_g_c(self.transformer_in, sample, num_frames)
            else:
                sample = self.transformer_in(sample, num_frames=num_frames).sample

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    num_frames=num_frames,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)

        sample = self.conv_out(sample)

        # reshape to (batch, channel, frames, height, width)
        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)
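
For reference, here is a minimal, illustrative usage sketch (not part of the PR). It assumes the module is importable alongside its `unet_3d_blocks` companion, uses a hypothetical import path, and follows the default config above: latents of shape (batch, 4, frames, height, width) and 1024-dimensional encoder hidden states.

# Illustrative only; the import path below is a placeholder for wherever this file lands.
import torch
from unet.unet_3d_condition import UNet3DConditionModel  # hypothetical path

unet = UNet3DConditionModel()  # defaults: in/out channels = 4, cross_attention_dim = 1024
unet.eval()

batch, frames, height, width = 1, 8, 32, 32
latents = torch.randn(batch, 4, frames, height, width)   # (B, C, F, H, W) noisy latents
timestep = torch.tensor([999])                            # diffusion timestep(s)
text_embeddings = torch.randn(batch, 77, 1024)            # (B, seq_len, cross_attention_dim)

with torch.no_grad():
    out = unet(latents, timestep, encoder_hidden_states=text_embeddings).sample

print(out.shape)  # torch.Size([1, 4, 8, 32, 32]); same (B, C, F, H, W) layout as the input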