Spaces:
Runtime error
Runtime error
| # Copyright 2023 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.embeddings import ImagePositionalEmbeddings | |
| from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph | |
| from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention | |
| from diffusers.models.embeddings import PatchEmbed | |
| from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.utils.import_utils import is_xformers_available | |
| from einops import rearrange | |
| import pdb | |
| import random | |
| if is_xformers_available(): | |
| import xformers | |
| import xformers.ops | |
| else: | |
| xformers = None | |
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    `BaseOutput` subclasses are declared as dataclasses so that the annotated
    fields below are registered on the output object.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data whose blocks are `BasicMVTransformerBlock`s.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Number of output channels; falls back to `in_channels` when `None`. Only the patched output
            projection reads it (the continuous projections use `in_channels`).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups of the input `GroupNorm` (continuous input only).
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*):
            Patch size; together with `in_channels` it selects the patched-input mode.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.
            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        use_linear_projection (`bool`, defaults to `False`):
            Use linear layers instead of 1x1 convolutions for the input/output projections (continuous input).
        only_cross_attention, upcast_attention, norm_type, norm_elementwise_affine:
            Forwarded unchanged to every `BasicMVTransformerBlock`.
        num_views (`int`, defaults to 1):
            Number of views; forwarded to every block, which hands it to its self-attention processor.
        joint_attention (`bool`, defaults to `False`):
            Forwarded to every block; enables the block's additional joint attention layer.
        joint_attention_twice (`bool`, defaults to `False`):
            Forwarded to every block; enables the block's second joint attention layer.
        multiview_attention (`bool`, defaults to `True`):
            Forwarded to every block's self-attention processor.
        cross_domain_attention (`bool`, defaults to `False`):
            Forwarded to every block's self-attention processor.
    """

    # `register_to_config` records the constructor arguments in `self.config`, which `ConfigMixin`
    # relies on for saving/loading the model configuration.
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        # Legacy configs set `num_embeds_ada_norm` while leaving `norm_type` at its default;
        # treat those as `ada_norm` and emit a deprecation warning.
        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # Exactly one input mode must be selected.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`dict`, *optional*):
                Extra keyword arguments handed to every transformer block.
            attention_mask ( `torch.Tensor`, *optional*):
                Self-attention mask. A 2-D tensor is interpreted as a keep/discard mask and converted to an
                additive bias. Note that `BasicMVTransformerBlock.forward` currently asserts that this is `None`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # NOTE: in this branch `inner_dim` is the channel count *before* `proj_in`; the output
                # path below reshapes with the same value after `proj_out` has projected back.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0)); computed in float64 and cast back to float32
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # Re-uses the conditioning embedding of the first block's `norm1` (requires that
            # norm layer to expose an `emb` module).
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify (the token count is treated as a square grid)
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block whose self-attention is a `CustomAttention` (multi-view aware processor) and
    which can optionally run additional "joint" attention layers.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, defaults to `False`): Forwarded to every attention layer of the block.
        norm_elementwise_affine (`bool`, defaults to `True`): `elementwise_affine` of the plain `LayerNorm`s.
        norm_type (`str`, defaults to `"layer_norm"`):
            `"ada_norm"` / `"ada_norm_zero"` select the adaptive norms (both require `num_embeds_ada_norm`).
        final_dropout (`bool`, defaults to `False`): Forwarded to the `FeedForward` layer.
        num_views (`int`, defaults to 1): Passed to the self-attention processor on every forward call.
        joint_attention (`bool`, defaults to `False`):
            Adds a `CustomJointAttention` layer that is applied after the feed-forward.
        joint_attention_twice (`bool`, defaults to `False`):
            Adds a second `CustomJointAttention` layer that is applied right after the self-attention.
        multiview_attention (`bool`, defaults to `True`): Passed to the self-attention processor.
        cross_domain_attention (`bool`, defaults to `False`): Passed to the self-attention processor.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.cross_domain_attention = cross_domain_attention
        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor(),
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        self.joint_attention = joint_attention

        if self.joint_attention:
            # Joint task -Attn
            self.attn_joint = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor(),
            )
            # Zero-initialise the output projection weight of the added branch (the bias is left untouched).
            nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
            self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.joint_attention_twice = joint_attention_twice

        if self.joint_attention_twice:
            # Joint task -Attn
            self.attn_joint_twice = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor(),
            )
            # Zero-initialise the output projection weight of the added branch (the bias is left untouched).
            nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
            self.norm_joint_twice = (
                AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
            )

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward along dimension `dim`."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Run self-attention, the optional joint/cross attention layers and the feed-forward, each with a
        residual connection, and return the updated `hidden_states`.

        `attention_mask` must be `None` (self-attention masking is not supported yet). `cross_attention_kwargs`
        is forwarded to both `attn1` and `attn2`.
        """
        assert attention_mask is None  # not supported yet
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            num_views=self.num_views,
            multiview_attention=self.multiview_attention,
            cross_domain_attention=self.cross_domain_attention,
            **cross_attention_kwargs,
        )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # joint attention twice
        if self.joint_attention_twice:
            norm_hidden_states = (
                self.norm_joint_twice(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm_joint_twice(hidden_states)
            )
            hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        # 4. Optional joint attention after the feed-forward
        if self.joint_attention:
            norm_hidden_states = (
                self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
            )
            hidden_states = self.attn_joint(norm_hidden_states) + hidden_states

        return hidden_states
class CustomAttention(Attention):
    """`Attention` variant that switches between the multi-view attention processors defined in this file."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install `XFormersMVAttnProcessor` when the flag is true, otherwise fall back to `MVAttnProcessor`.

        Extra positional/keyword arguments are accepted for signature compatibility and ignored.
        """
        # Honour the flag: a call with `False` must not leave the xformers processor installed.
        if use_memory_efficient_attention_xformers:
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()
        self.set_processor(processor)
class CustomJointAttention(Attention):
    """`Attention` variant that switches between the joint attention processors defined in this file."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install `XFormersJointAttnProcessor` when the flag is true, otherwise fall back to `JointAttnProcessor`.

        Extra positional/keyword arguments are accepted for signature compatibility and ignored.
        """
        # Honour the flag: a call with `False` must not leave the xformers processor installed.
        if use_memory_efficient_attention_xformers:
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
class MVAttnProcessor:
    r"""
    Attention processor for multi-view self-attention using the plain (`torch.bmm`) attention path.

    When `multiview_attention` is enabled (and `num_views <= 6`), the keys and values of all views that
    belong to the same sample are concatenated along the token axis, so every view attends to the tokens
    of all views.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cross_domain_attention=False,
    ):
        """
        Args:
            attn: The attention module whose projections and helpers are used.
            hidden_states: Query input, either 3-D `(batch, tokens, channels)` or 4-D
                `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value input; defaults to `hidden_states`.
            attention_mask: Optional mask, prepared through `attn.prepare_attention_mask`.
            temb: Optional embedding for `attn.spatial_norm`.
            num_views: Number of views folded into the batch axis (batch is laid out as `(b t)`).
            multiview_attention: Whether to share keys/values across the views of a sample.
            cross_domain_attention: Accepted because `BasicMVTransformerBlock` always passes it;
                only `False` is supported by this processor.

        Returns:
            The attended hidden states, in the same layout as the input.

        Raises:
            NotImplementedError: If `cross_domain_attention` is true.
        """
        if cross_domain_attention:
            # Fail loudly instead of silently running without cross-domain attention.
            raise NotImplementedError(
                "`cross_domain_attention` is not implemented by `MVAttnProcessor`; "
                "use the xformers attention processor instead."
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # multi-view self-attention
        if multiview_attention and num_views <= 6:
            # ((b t), d, c) -> (b, (t d), c): gather the tokens of all views of a sample, then repeat
            # so that each of the `num_views` queries of that sample sees the full set.
            key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
            value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
        # NOTE: for more than 6 views a sparse attention scheme was planned but is not implemented;
        # keys/values are left untouched, i.e. every view only attends to itself.

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class XFormersMVAttnProcessor:
    r"""
    Multi-view attention processor that runs on xFormers memory-efficient attention.

    When `multiview_attention` is set, consecutive groups of `num_views` batch entries share one key/value sequence
    built by concatenating their tokens, so each entry attends to every entry of its group. When
    `cross_domain_attention` is also set, the keys/values of the opposite half of the batch are appended as well.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1.,
        multiview_attention=True,
        cross_domain_attention=False,
    ):
        r"""
        Args:
            attn: Attention module providing the projections, norms and head reshaping helpers.
            hidden_states: Query input, either `(batch, tokens, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value input; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask` and used as `attn_bias`.
            temb: Only forwarded to `attn.spatial_norm` when that module is present.
            num_views: Number of batch entries per multi-view group; must be a whole number >= 1.
            multiview_attention: Whether to concatenate keys/values across the entries of each group.
            cross_domain_attention: Whether to also append the keys/values of the other half of the batch
                (only honoured together with `multiview_attention`).

        Returns:
            Tensor with the same shape as the incoming `hidden_states`.

        Raises:
            ValueError: If `multiview_attention` is set and `num_views` is not a whole number >= 1.
            ImportError: If xformers is not installed.
        """
        if multiview_attention:
            # `num_views` is an einops axis length and a `repeat_interleave` count below; both need a real int,
            # but the default (`1.`) and some callers hand in a float.
            views = int(num_views)
            if views != num_views or views < 1:
                raise ValueError(f"`num_views` must be a positive whole number, got {num_views!r}.")

        if xformers is None:
            raise ImportError("XFormersMVAttnProcessor requires the `xformers` package to be installed.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # NOTE(review): the mask is sized from `sequence_length` and is not widened when keys are concatenated
        # below; the original author notes that attention_mask is None on this path.
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        if multiview_attention:
            # merge the tokens of each group of `views` entries, then repeat so the batch size matches the queries
            key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=views).repeat_interleave(views, dim=0)
            value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=views).repeat_interleave(views, dim=0)

            if cross_domain_attention:
                # memory efficient, cross domain attention: swap the two batch halves and append them as extra tokens
                key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)  # keys shape (b t) d c
                value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
                key_cross = torch.concat([key_1, key_0], dim=0)
                value_cross = torch.concat([value_1, value_0], dim=0)  # shape (b t) d c
                key = torch.cat([key, key_cross], dim=1)
                value = torch.cat([value, value_cross], dim=1)  # shape (b t) (t+1 d) c
        else:
            key = key_raw
            value = value_raw

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # restore the (b, c, h, w) layout the caller passed in
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class XFormersJointAttnProcessor:
    r"""
    Joint two-task attention processor that runs on xFormers memory-efficient attention.

    The key/value batch is split into two halves (one per task); the halves are concatenated along the token axis and
    tiled back to the full batch, so every query attends to the tokens of both halves.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        r"""
        Args:
            attn: Attention module providing the projections, norms and head reshaping helpers.
            hidden_states: Query input, either `(batch, tokens, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value input; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask` and used as `attn_bias`.
            temb: Only forwarded to `attn.spatial_norm` when that module is present.
            num_tasks: Number of tasks stacked along the batch axis; only `2` is supported.

        Returns:
            Tensor with the same shape as the incoming `hidden_states`.

        Raises:
            ValueError: If `num_tasks` is not 2 or the key/value batch cannot be split into two equal halves.
            ImportError: If xformers is not installed.
        """
        # explicit check instead of `assert`, which is stripped when Python runs with -O
        if num_tasks != 2:
            raise ValueError(f"Only `num_tasks=2` is supported, got {num_tasks!r}.")

        if xformers is None:
            raise ImportError("XFormersJointAttnProcessor requires the `xformers` package to be installed.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # NOTE(review): the mask is sized from `sequence_length` and is not widened when keys are concatenated
        # below; the original author notes that attention_mask is None on this path.
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if key.shape[0] % 2 != 0:
            raise ValueError(f"The batch size must be even to split it into two tasks, got {key.shape[0]}.")

        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)  # keys shape (b t) d c
        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
        key = torch.cat([key_0, key_1], dim=1)  # (b t) 2d c
        value = torch.cat([value_0, value_1], dim=1)  # (b t) 2d c
        # tile back to the full batch so both halves of the queries see the joint keys/values
        key = torch.cat([key] * 2, dim=0)  # (2 b t) 2d c
        value = torch.cat([value] * 2, dim=0)  # (2 b t) 2d c

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # restore the (b, c, h, w) layout the caller passed in
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class JointAttnProcessor:
    r"""
    Joint two-task attention processor using the attention module's own score computation.

    The key/value batch is split into two halves (one per task); the halves are concatenated along the token axis and
    tiled back to the full batch, so every query attends to the tokens of both halves.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        r"""
        Args:
            attn: Attention module providing the projections, norms, score computation and head reshaping helpers.
            hidden_states: Query input, either `(batch, tokens, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value input; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask` and on to
                `attn.get_attention_scores`.
            temb: Only forwarded to `attn.spatial_norm` when that module is present.
            num_tasks: Number of tasks stacked along the batch axis; only `2` is supported.

        Returns:
            Tensor with the same shape as the incoming `hidden_states`.

        Raises:
            ValueError: If `num_tasks` is not 2 or the key/value batch cannot be split into two equal halves.
        """
        # explicit check instead of `assert`, which is stripped when Python runs with -O
        if num_tasks != 2:
            raise ValueError(f"Only `num_tasks=2` is supported, got {num_tasks!r}.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # NOTE(review): the mask is sized from `sequence_length` and is not widened when keys are concatenated
        # below - confirm callers only pass `attention_mask=None` here.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if key.shape[0] % 2 != 0:
            raise ValueError(f"The batch size must be even to split it into two tasks, got {key.shape[0]}.")

        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)  # keys shape (b t) d c
        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
        key = torch.cat([key_0, key_1], dim=1)  # (b t) 2d c
        value = torch.cat([value_0, value_1], dim=1)  # (b t) 2d c
        # tile back to the full batch so both halves of the queries see the joint keys/values
        key = torch.cat([key] * 2, dim=0)  # (2 b t) 2d c
        value = torch.cat([value] * 2, dim=0)  # (2 b t) 2d c

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # restore the (b, c, h, w) layout the caller passed in
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states