import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

torch.backends.cuda.enable_mem_efficient_sdp(True)


def create_sin_embedding(positions, dim, max_period=10000):
    # assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=torch.float)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    # Official implementation computes this in torch.float32, while self_attn.in_proj_weight is torch.float16.
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


class StreamingMultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, cross_attention=False):
        super().__init__()
        self.cross_attention = cross_attention
        # KV caching is used only when not self.cross_attention. The history is cleared by the LM
        # between generations; each of the 0..47 MHA layers keeps its own kv history.
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim), dtype=torch.float))

    def forward(self, query, key=None, value=None):
        layout = "b h t d"
        if self.cross_attention:
            # Different queries, keys, values -> split in_proj_weight into its q / k / v blocks.
            dim = self.in_proj_weight.shape[0] // 3
            q = F.linear(query, self.in_proj_weight[:dim])
            k = F.linear(key, self.in_proj_weight[dim: 2 * dim])
            v = F.linear(value, self.in_proj_weight[2 * dim:])
            q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            # Self-attention of audio with itself (the branch above is cross-attention to the text condition).
            # A single projection yields q, k, v for the current step; the kv history differs per transformer layer.
            # Floating-point values here differ slightly from the official implementation.
            projected = F.linear(query, self.in_proj_weight, None)
            # print(query.sum(), projected.sum(), self.in_proj_weight.sum(), 'Lc')  # verified official AudioGen values
            bound_layout = "b h p t d"
            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)

            if self.k_history is not None:
                # If generation is interrupted (e.g. Ctrl-C in the live demo) between the two assignments
                # below, k_history and v_history end up with mismatched lengths, and the next call would
                # continue with incompatible k/v dims.
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v

            # Attend over the full accumulated history. The kv cache only applies
            # when not self.cross_attention.
            k = self.k_history
            v = self.v_history

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x


class StreamingTransformerLayer(nn.Module):

    def __init__(self, d_model, num_heads, dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model, num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model, num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self, x, cross_attention_src=None):
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)  # text conditioning
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x


class StreamingTransformer(nn.Module):

    def __init__(self, d_model=1536, num_heads=24, num_layers=48, dim_feedforward=6144):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
            ]
        )

    def forward(self, x, cache_position=None, cross_attention_src=None):
        # Absolute positional embedding for the current step; dim follows d_model.
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, x.shape[-1])
        for lay in self.layers:
            x = lay(x, cross_attention_src=cross_attention_src)
        return x

    def _flush(self, n_preserve=None):
        for lay in self.layers:
            if n_preserve is not None:
                # Keep only the first n_preserve steps of each layer's kv cache; picking a cache_position
                # that would also preserve kv from the end is non-trivial.
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
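

# Minimal usage sketch (an assumption, not part of the original module): it shows how the per-layer
# kv cache is meant to be driven during step-by-step generation. The small d_model / num_layers, the
# random cross_attention_src standing in for text-conditioning embeddings, and the random per-step
# inputs are illustrative placeholders; in the real model the weights come from a checkpoint and the
# next-step input comes from the LM heads.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2, dim_feedforward=256).eval()
    batch, text_len = 1, 7
    cross_src = torch.randn(batch, text_len, 64)  # placeholder for text-condition embeddings

    with torch.no_grad():
        x = torch.randn(batch, 1, 64)  # one frame per step
        for step in range(4):
            # each call appends this step's k/v to every layer's k_history / v_history
            y = model(x, cache_position=step, cross_attention_src=cross_src)
            x = torch.randn(batch, 1, 64)  # placeholder for the next-step input
        # drop the per-layer kv cache before starting a new generation
        model._flush()
        print(y.shape)  # torch.Size([1, 1, 64])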