Spaces:
Running
Running
| import torch | |
| from torch import nn, Tensor | |
| from einops import rearrange | |
| from concept_attention.flux.src.flux.modules.layers import Modulation, QKNorm | |
| from concept_attention.flux.src.flux.math import attention | |
# Count of image patch tokens. Not referenced by the class below; presumably
# used elsewhere to split image tokens from text tokens (e.g. a 64x64 latent
# grid) — NOTE(review): confirm against callers.
NUM_IMAGE_PATCHES = 4_096
class ModifiedSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    A single fused projection (``linear1``) produces both the q/k/v
    activations (``3 * hidden_size`` features) and the MLP input
    (``mlp_hidden_dim`` features). A second fused projection (``linear2``)
    maps the concatenation of the attention output and the activated MLP
    branch back to ``hidden_size``. The result is added to the input as a
    gated residual.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        """Build the block's projections, norms and modulation.

        Args:
            hidden_size: Width of the token features entering and leaving
                the block.
            num_heads: Number of attention heads. The per-head width is
                ``hidden_size // num_heads`` (floor division).
                NOTE(review): ``hidden_size`` presumably must be divisible
                by ``num_heads``, or the head split in ``forward`` fails.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``,
                truncated with ``int()``.
            qk_scale: Stored as ``self.scale``. When falsy (including
                ``None``), it falls back to ``head_dim ** -0.5``.
        """
        super().__init__()
        per_head_dim = hidden_size // num_heads
        mlp_width = int(hidden_size * mlp_ratio)

        # Plain (non-module) attributes. hidden_dim and hidden_size carry the
        # same value; both are kept because external code may read either.
        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_width
        # Not read by forward(); retained for interface compatibility.
        self.scale = qk_scale or per_head_dim**-0.5

        # Submodules are registered in the same order as before so that
        # parameter / named_modules ordering is unchanged.
        # Fused projection: [q | k | v | mlp_in].
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_width)
        # Fused projection: [attn_out | mlp_act(mlp_in)] -> hidden_size.
        self.linear2 = nn.Linear(hidden_size + mlp_width, hidden_size)
        self.norm = QKNorm(per_head_dim)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Apply joint self-attention and the parallel MLP as a gated residual.

        Args:
            x: Token features. The ``"B L (K H D)"`` head split below implies
                a rank-3 tensor with last dimension ``hidden_size``.
            vec: Conditioning input passed to ``self.modulation``; its shape
                is defined by ``Modulation`` (not visible here).
            pe: Positional-encoding tensor forwarded unchanged to
                ``attention``.

        Returns:
            ``x + gate * linear2(cat(attn, mlp_act(mlp_in)))``, where
            ``gate`` is the ``gate`` field returned by ``self.modulation``.
        """
        # Only the first modulation result is used (the block is built with
        # double=False).
        modulation, _ = self.modulation(vec)

        normed = self.pre_norm(x)
        modulated = (1 + modulation.scale) * normed + modulation.shift

        fused = self.linear1(modulated)
        qkv, mlp_in = fused.split([3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # Unpack heads: (B, L, 3*H*D) -> three tensors of shape (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn_out = attention(q, k, v, pe=pe)
        mlp_out = self.mlp_act(mlp_in)
        merged = torch.cat((attn_out, mlp_out), 2)

        return x + modulation.gate * self.linear2(merged)