import math
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint  # used by MHA.forward when checkpointing=True
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.

        Returns
        -------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

						
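# Usage sketch (illustrative, not part of the original file). FlashSelfAttention
# takes a packed qkv tensor in fp16/bf16 on CUDA; the varlen path takes the
# unpadded (total, 3, H, D) layout plus cu_seqlens / max_seqlen:
#
#     attn = FlashSelfAttention(causal=True)
#     qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device="cuda")
#     out = attn(qkv)  # (2, 1024, 16, 64)
#
#     # Two sequences of lengths 1000 and 48, packed without padding:
#     qkv_unpad = torch.randn(1048, 3, 16, 64, dtype=torch.float16, device="cuda")
#     cu_seqlens = torch.tensor([0, 1000, 1048], dtype=torch.int32, device="cuda")
#     out_unpad = attn(qkv_unpad, cu_seqlens=cu_seqlens, max_seqlen=1000)  # (1048, 16, 64)
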
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

						
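# Usage sketch (illustrative, not part of the original file). Cross-attention takes
# separate q and kv-packed tensors; kv may have a different sequence length (and,
# for MQA/GQA, fewer heads) than q:
#
#     xattn = FlashCrossAttention()
#     q = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")
#     kv = torch.randn(2, 512, 2, 16, 64, dtype=torch.float16, device="cuda")
#     out = xattn(q, kv)  # (2, 128, 16, 64)
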
class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            # Convert the boolean mask into an additive bias: 0.0 where tokens are kept,
            # -10000.0 where they are masked out.
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # Build the additive causal mask in float (triu is not implemented for all
            # low-precision dtypes on CUDA), then cast and add it to the scores.
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

						
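# Usage sketch (illustrative, not part of the original file). The reference path
# runs in any dtype on any device; padding is handled with a boolean mask where
# True marks real tokens:
#
#     attn = SelfAttention(causal=True)
#     qkv = torch.randn(2, 10, 3, 4, 32)
#     mask = torch.ones(2, 10, dtype=torch.bool)
#     mask[1, 7:] = False                      # second sequence has only 7 real tokens
#     out = attn(qkv, key_padding_mask=mask)   # (2, 10, 4, 32)
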
class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:
            # MQA/GQA: repeat the kv heads so each query head has a matching kv head.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # Causal mask aligned to the bottom-right corner: query i attends to key j iff
            # j <= i + sk - seqlen_q, matching FlashAttention's convention when
            # seqlen_q != seqlen_k.
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

						
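# GQA sketch (illustrative, not part of the original file): with 16 query heads and
# 4 kv heads, the repeat above expands kv from (B, Sk, 2, 4, D) to (B, Sk, 2, 16, D),
# so each group of 4 consecutive query heads shares one original kv head:
#
#     xattn = CrossAttention()
#     q = torch.randn(2, 8, 16, 32)
#     kv = torch.randn(2, 24, 2, 4, 32)
#     out = xattn(q, kv)  # (2, 8, 16, 32)
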
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Return (output, input) so callers can fuse the residual connection.
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Lazily allocate the per-layer kv cache on first use.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]

    # Write the new kv into its slot, then return the cache up to and including
    # the newly written positions.
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert kv_cache is not None
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]

						
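# The inference_params objects used below are assumed to follow the Megatron-style
# InferenceParams protocol (in flash-attn this container lives in
# flash_attn.utils.generation). A minimal sketch of the attributes this file
# actually reads, inferred from usage:
#
#     @dataclass
#     class InferenceParams:
#         max_seqlen: int
#         max_batch_size: int
#         seqlen_offset: int = 0
#         batch_size_offset: int = 0
#         key_value_memory_dict: dict = field(default_factory=dict)
#         lengths_per_sample: torch.Tensor | None = None
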
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        # This module does not implement rotary embeddings; keep the attribute at 0 so
        # forward() always takes the non-rotary path when decoding with a kv cache.
        self.rotary_emb_dim = 0

        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Depthwise conv with padding=2 followed by [..., :-2] trimming gives a causal
            # (left-padded) short convolution over the sequence dimension.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

						
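    # Construction sketch (illustrative, not part of the original file): a causal
    # GQA layer with 16 query heads sharing 4 kv heads on the FlashAttention path:
    #
    #     mha = MHA(
    #         embed_dim=1024, num_heads=16, num_heads_kv=4, causal=True,
    #         use_flash_attn=True, layer_idx=0, device="cuda", dtype=torch.float16,
    #     )
    #     y = mha(torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16))
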
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines updating the kv cache and applying attention in one kernel.
        (Named after the flash-attn original, which also applies rotary embeddings to q and k
        here; this trimmed module passes no rotary cos/sin, so none is applied.)
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,
            alibi_slopes=None,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Slow path: materialize the full kv from the cache, then run regular
            # cross-attention between q and the cached kv.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=None,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        # Kept from the original rotary code path; unused in this trimmed module.
        rotary_max_seqlen = (
            inference_params.max_seqlen if inference_params is not None else max_seqlen
        )
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            # The depthwise conv runs on the flat (b, s, d) layout, so apply it before
            # splitting the head dimension.
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)

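# Smoke test (illustrative sketch, not part of the original module). It runs the
# pure-PyTorch path and, when CUDA and flash-attn are available, checks that the
# FlashAttention path agrees with it under a loose fp16 tolerance.
if __name__ == "__main__":
    torch.manual_seed(0)
    b, s, h, d = 2, 128, 8, 64
    mha = MHA(embed_dim=h * d, num_heads=h, causal=True)
    x = torch.randn(b, s, h * d)
    out = mha(x)
    assert out.shape == (b, s, h * d)
    if torch.cuda.is_available() and flash_attn_qkvpacked_func is not None:
        mha_flash = MHA(embed_dim=h * d, num_heads=h, causal=True, use_flash_attn=True)
        mha_flash.load_state_dict(mha.state_dict())
        mha_flash = mha_flash.to("cuda", torch.float16)
        out_flash = mha_flash(x.to("cuda", torch.float16))
        torch.testing.assert_close(out_flash.float().cpu(), out, atol=3e-2, rtol=3e-2)
    print("MHA smoke test passed")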