import math
from typing import Literal

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from fla.models.utils import Cache
from torch import nn


def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor:
    """Restrict an attention mask to a causal sliding window of `window_size` keys."""
    B, H, Q, KV = mask.shape
    device = mask.device

    q_idx = torch.arange(Q, device=device).unsqueeze(1)  # (Q, 1)
    k_idx = torch.arange(KV, device=device).unsqueeze(0)  # (1, KV)
    lower_bound = q_idx - (window_size - 1)  # (Q, 1), may be negative

    allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound)  # (Q, KV), dtype=torch.bool
    allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV)

    orig_dtype = mask.dtype
    mask_bool = mask if mask.dtype == torch.bool else mask.to(torch.bool)

    new_mask = mask_bool & allowed_4d
    if orig_dtype != torch.bool:
        return new_mask.to(orig_dtype)
    return new_mask


def precompute_freqs_cis_(
    t: torch.Tensor,
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    """Unbatched rotary frequency table: t has shape (T,), output (T, n_elem)."""
    freqs = 1.0 / (
        base
        ** (
            torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float() / n_elem
        )
    )
    freqs = torch.outer(t, freqs)
    cache = repeat(freqs, "... d -> ... (d 2)")
    return cache


def precompute_freqs_cis(
    t: torch.Tensor,  # shape: (B, T) or (T,)
    n_elem: int,
    base: float = 10000,
) -> torch.Tensor:
    """
    Batched version of precompute_freqs_cis_.

    Args:
        t: torch.Tensor, shape (B, T) or (T,)
            Timesteps to compute frequencies for.
        n_elem: int
            Embedding dimension (must be even).
        base: float
            Base for frequency computation (default: 10000).

    Returns:
        cache: torch.Tensor, shape (B, T, n_elem); a 1-D input is treated as a
        batch of size 1.
    """
    if t.dim() == 1:  # unbatched
        t = t.unsqueeze(0)  # (1, T)

    B, T = t.shape
    device = t.device

    # frequencies (half dimension, then expand back)
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2, device=device)[: (n_elem // 2)].float() / n_elem)
    )  # shape: (n_elem // 2,)

    # outer product for each batch -> (B, T, n_elem // 2)
    freqs = torch.einsum("bt,d->btd", t, freqs)

    # duplicate last dim to interleave sin/cos pairs -> (B, T, n_elem)
    cache = repeat(freqs, "... d -> ... (d 2)")

    # if cache.shape[0] == 1:  # if originally unbatched
    #     cache = cache.squeeze(0)  # (T, n_elem)

    return cache


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
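

# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch showing the banded mask produced by
# `apply_causal_sliding_window`: each query may attend to itself and to the
# `window_size - 1` keys immediately before it. The helper name below is
# hypothetical and exists only as documentation.
def _example_sliding_window_mask() -> torch.Tensor:
    full = torch.ones(1, 1, 5, 5, dtype=torch.bool)  # (B, H, Q, KV), everything allowed
    banded = apply_causal_sliding_window(full, window_size=3)
    # Query position 4 may only attend to key positions 2, 3 and 4.
    assert banded[0, 0, 4].tolist() == [False, False, True, True, True]
    return banded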
(d r)") def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin() return out def scaled_dot_product_attention(query, key, value, mask=None): scale_factor = 1 / math.sqrt(query.size(-1)) attn_weight = query @ key.transpose(-2, -1) * scale_factor if mask is not None: attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max) attn_weight = torch.softmax(attn_weight, dim=-1) return attn_weight @ value, attn_weight class SelfAttention(nn.Module): def __init__( self, dim: int, num_heads: int, layer_idx: int, is_causal: bool = False, sliding_window: int | None = None, ): super().__init__() self.qkv = nn.Linear(dim, 3 * dim) assert dim % num_heads == 0 self.heads = num_heads self.is_causal = is_causal self.layer_idx = layer_idx self.output_proj = nn.Linear(dim, dim) self.sliding_window = sliding_window if self.sliding_window is not None: self.is_causal = False def forward( self, x, freqs: torch.Tensor | None = None, mask: torch.Tensor | None = None, cache: Cache | None = None, ): B, T, D = x.shape q, k, v = self.qkv(x).chunk(3, dim=-1) q, k, v = map( lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) ) if freqs is not None: q = apply_rotary_emb(q, freqs) k = apply_rotary_emb(k, freqs) if cache is not None: cache.update(attn_state=(k, v), layer_idx=self.layer_idx, offset=T) k, v = cache[self.layer_idx]["attn_state"] if self.sliding_window is not None: mask = torch.ones(B, 1, T, T, device=x.device) mask = apply_causal_sliding_window(mask, self.sliding_window) y = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1 ) y = rearrange(y, "b h n d -> b n (h d)") y = self.output_proj(y) return y class CrossAttention(nn.Module): def __init__( self, dim: int, num_heads: int, layer_idx: int | None = None, dropout: float = 0.1, ): super().__init__() assert dim % num_heads == 0 self.pre_norm_q = nn.LayerNorm(dim) self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.layer_idx = layer_idx self.heads = num_heads self.dropout_att = dropout def _prepare_kv(self, text_hidden_states: torch.Tensor): v = self.ln_v(self.v(text_hidden_states)) k = self.ln_k(self.k(text_hidden_states)) def _query(self, x): return self.q(self.pre_norm_q(q)) def forward( self, q: torch.Tensor, k: torch.Tensor | None = None, v: torch.Tensor | None = None, mask: torch.Tensor | None = None, output_attention: bool = False, cache: Cache | None = None, **kwargs, ): if v is None: v = k q = self.q(self.pre_norm_q(q)) if cache is not None: if cache[self.layer_idx] is not None: ca_state = cache[self.layer_idx]["crossatt_state"] if ca_state is not None: k, v = ca_state else: v = self.v(v) k = self.k(k) cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx) else: v = self.v(v) k = self.k(k) q, k, v = map( lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) ) if mask is not None: if mask.ndim == 3: mask = mask[:, None] # if not self.training: if not self.training: x, att = scaled_dot_product_attention(q, k, v, mask=mask) else: x = nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=self.dropout_att ) att = None x = rearrange(x, "b h n d -> b n (h d)") if att is not None: if cache is not None: cache.update(crossatt_weights=att, layer_idx=self.layer_idx) else: self.att = att return x class ConvPos(nn.Module): def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2): super().__init__() 


class ConvPos(nn.Module):
    """Learned positional embedding followed by a depthwise 1-D convolution."""

    def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2):
        super().__init__()
        self.embed = nn.Embedding(max_seq_len * n_parallel_codebook, dim)
        self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same")
        self.max_seq_len = max_seq_len
        self.n_parallel_codebook = n_parallel_codebook

    def forward(self, x, left_shift=0, random_shift=False):
        # left_pad = 31 if left_shift > 0 else 0
        # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0), x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0)
        if random_shift:
            bias = torch.randint(
                0,
                self.n_parallel_codebook,
                (x.shape[0],),
                device=x.device,
            )
            x = x + bias * self.max_seq_len
        y = self.embed(x)
        y = rearrange(y, "b n c -> b c n")
        y = self.dw_conv(y)
        y = rearrange(y, "b c n -> b n c")  # [:, left_pad:-31]
        return y


class SinPos(nn.Module):
    """Sinusoidal positional embedding over integer positions."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, left_shift=0):
        # `left_shift` is accepted for API parity with ConvPos but is not used here.
        exp = torch.arange(self.dim // 2, device=x.device)
        exp = 2 * exp / self.dim
        exp = rearrange(exp, "e -> 1 1 e")
        x = rearrange(x, "b p -> b p 1")
        pos = x * torch.pow(10000, -exp)
        pos = torch.cat((pos, pos + math.pi / 2), dim=2)
        pos = torch.sin(pos)
        return pos


class BlindCrossAttention(nn.Module):
    """Two-stage cross-attention routed through positional embeddings: queries
    first attend over the keys to gather positional codes, `pos_net` transforms
    them, and a second attention over the positional codes retrieves the values.
    """

    def __init__(
        self,
        q_dim,
        k_dim,
        att_dim,
        pos_net,
        dropout=0.1,
        pos_dim=64,
        pos_type="sinusoidal",
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.q = nn.Linear(q_dim, att_dim)
        self.k = nn.Linear(k_dim, att_dim)
        self.v = nn.Linear(k_dim, att_dim)
        self.pos_net = pos_net
        if pos_type == "sinusoidal":
            self.pos_embed = SinPos(pos_dim)
        elif pos_type == "convolutional":
            self.pos_embed = ConvPos(pos_dim)
        self.ln_q = nn.LayerNorm(att_dim)
        self.ln_k = nn.LayerNorm(att_dim)
        self.ln_v = nn.LayerNorm(att_dim)
        self.dropout_att = nn.Dropout(dropout)
        self.layer_idx = layer_idx

    def _prepare_kv(self, text_hidden_states: torch.Tensor):
        v = self.ln_v(self.v(text_hidden_states))
        k = self.ln_k(self.k(text_hidden_states))
        b, j, d = k.shape
        pos = torch.arange(j, device=k.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        return {"k": k, "v": v, "pos_emb": pos_emb}

    def _query(self, x):
        return self.ln_q(self.q(x))

    def forward(
        self,
        q,
        k,
        kv_cached=None,
        mask=None,
        time_step=None,
        pos=None,
        left_shift=0,
        past_key_values=None,
        cache=None,
        **kwargs,
    ):
        q = self.ln_q(self.q(q))
        # if kv_cached is None:
        #     v = self.ln_v(self.v(k))
        #     k = self.ln_k(self.k(k))
        # else:
        #     k, v = kv_cached
        if mask is not None:
            mask = mask.unsqueeze(1)
        if cache is not None:
            if cache[self.layer_idx] is not None:
                ca_state = cache[self.layer_idx]["crossatt_state"]
                if ca_state is not None:
                    k, v, pos_emb = ca_state
            else:
                v = self.ln_v(self.v(k))
                k = self.ln_k(self.k(k))
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
                pos_emb = self.pos_embed(pos, left_shift=left_shift)
                cache.update(
                    crossatt_state=(k, v, pos_emb), layer_idx=self.layer_idx
                )
        else:
            v = self.ln_v(self.v(k))
            k = self.ln_k(self.k(k))
            if pos is None:
                pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
            pos_emb = self.pos_embed(pos, left_shift=left_shift)

        q, k, v = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k, v))
        b, h, j, d = k.shape

        if self.training:
            sdpa = lambda q, k, pos: (
                nn.functional.scaled_dot_product_attention(
                    q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
                ),
                None,
            )
        else:
            sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)

        # First stage: content attention that gathers positional codes.
        x, att1 = sdpa(q, k, pos_emb.unsqueeze(1))
        x = rearrange(x, "b 1 n d -> b n d")
        x = self.pos_net(x, cache=cache)
        x = rearrange(x, "b n d -> b 1 n d")
        pos_emb = rearrange(pos_emb, "b n d -> b 1 n d")
        # Second stage: positional attention that retrieves the values.
        x, att2 = sdpa(x, pos_emb, v)
        x = rearrange(x, "b 1 n d -> b n d")
        self.att1 = att1
        self.att2 = att2
        if att2 is not None:
            if cache is not None:
                cache.update(
                    crossatt_weights=torch.cat((att1, att2), dim=1),
                    layer_idx=self.layer_idx,
                )
        return x
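

# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch of `BlindCrossAttention` in eval mode. `pos_net` is assumed
# to be any module with a `forward(x, cache=None)` signature; the identity
# stand-in `_IdentityPosNet` and the helper name are hypothetical and only
# illustrate the expected shapes.
class _IdentityPosNet(nn.Module):
    def forward(self, x, cache=None):
        return x


def _example_blind_cross_attention() -> torch.Tensor:
    B, Tq, Tk = 2, 6, 9
    q_dim, k_dim, att_dim, pos_dim = 32, 48, 64, 64
    layer = BlindCrossAttention(
        q_dim, k_dim, att_dim, pos_net=_IdentityPosNet(), pos_dim=pos_dim
    ).eval()
    q = torch.randn(B, Tq, q_dim)
    k = torch.randn(B, Tk, k_dim)
    out = layer(q, k)
    assert out.shape == (B, Tq, att_dim)
    return out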
rearrange(x, "b 1 n d -> b n d") self.att1 = att1 self.att2 = att2 if att2 is not None: if cache is not None: cache.update( crossatt_weights=torch.cat((att1, att2), dim=1), layer_idx=self.layer_idx, ) return x class ListenReadCrossAttention(nn.Module): def __init__( self, q_dim: int, k_dim: int, att_dim: int, crossatt_type: Literal["listen", "read"], num_heads: int = 1, dropout: float = 0.1, layer_idx: int | None = None, ): super().__init__() self.q = nn.Linear(q_dim, att_dim) self.k = nn.Linear(k_dim, att_dim) self.ln_q = nn.LayerNorm(att_dim) self.ln_k = nn.LayerNorm(att_dim) self.dropout_att = nn.Dropout(dropout) self.crossatt_type = crossatt_type self.layer_idx = layer_idx def forward( self, q: torch.Tensor, k: torch.Tensor, text_freqs: torch.Tensor, mask: torch.Tensor | None = None, past_key_values=None, cache=None, **kwargs, ): q = self.ln_q(self.q(q)) k = self.ln_k(self.k(k)) if mask is not None: mask = mask.unsqueeze(1) q, k = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k)) if self.training: sdpa = lambda q, k, pos: ( nn.functional.scaled_dot_product_attention( q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p ), None, ) else: sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask) text_freqs = rearrange(text_freqs, "b n d -> b 1 n d") if self.crossatt_type == "listen": x, att = sdpa(q, k, text_freqs) elif self.crossatt_type == "read": x, att = sdpa(q, text_freqs, k) else: raise ValueError x = rearrange(x, "b 1 n d -> b n d") if att is not None: if cache is not None: cache.update( crossatt_weights=att, layer_idx=self.layer_idx, ) self.att = att return x