# CrossAttn precision handling
import os
from typing import Any, Optional

import torch
from torch import nn
from torch import einsum
from einops import rearrange, repeat

from ...patches import router

# Attention-score precision: with the default "fp32", q/k are upcast before the
# similarity matmul so the softmax logits cannot overflow fp16 under autocast.
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        # Self-attention when no context is given, cross-attention otherwise.
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        del q, k

        if mask is not None:
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", sim, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class PatchedCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim

        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head**-0.5

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        # Same projections as CrossAttention, but the attention computation itself
        # is dispatched through the patch router (e.g. a memory-efficient backend).
        return router.attention_forward(self, x, context, mask)
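

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal smoke test for CrossAttention, assuming SD-style dimensions:
# 320-dim latent tokens attending over a 77-token, 768-dim text context.
# The function name and all sizes here are assumptions chosen for the example,
# not constants taken from this repo.
def _cross_attention_example():
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)        # (batch, query tokens, query_dim)
    context = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
    # With the default ATTN_PRECISION="fp32", scores are computed in float32.
    out = attn(x, context)
    assert out.shape == (2, 64, 320)
    return out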