import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) as used in LLaMA-style attention."""

    def __init__(
        self,
        dim: int = 64,  # dimension per attention head
        max_seq_len: int = 2048,  # maximum sequence length the module expects
        base: int = 10000,  # base for the angle (frequency) calculations
        device: str | None = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies: theta_i = base^(-2i / dim) for i = 0 .. dim/2 - 1
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily built cos/sin cache, grown on demand in forward()
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_len: int):
        # Reuse the cache when it is long enough and lives on the same device as x
        if (
            seq_len <= self._seq_len_cached
            and self._cos_cached is not None
            and self._cos_cached.device == x.device
        ):
            return

        # Grow the cache to the new sequence length
        self._seq_len_cached = seq_len

        # Position indices 0 .. seq_len - 1
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)

        # Outer product: angles[m, i] = m * theta_i, shape (seq_len, dim / 2)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate along the last dim so cos/sin line up with rotate_half: (seq_len, dim)
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos()
        self._sin_cached = emb.sin()

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch, num_heads, seq_len, head_dim = q.shape

        # Make sure the cos/sin tables cover the current sequence length
        self._update_cos_sin_tables(q, seq_len)

        # Slice the cached tables and add broadcast dims: (1, 1, seq_len, dim)
        cos = self._cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self._sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        # Apply the rotation: x * cos(m * theta) + rotate_half(x) * sin(m * theta)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed
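

# Usage sketch (shapes below are assumptions for illustration): the module applies
# rotary position information to q and k, and both tensors keep their
# (batch, num_heads, seq_len, head_dim) shape. Note that rotate_half pairs feature i
# with feature i + dim // 2, which is why cos/sin are duplicated across the two
# halves in _update_cos_sin_tables.
#
#   rope = LlamaRotaryEmbedding(dim=64)
#   q = torch.randn(2, 8, 16, 64)   # (batch, num_heads, seq_len, head_dim)
#   k = torch.randn(2, 8, 16, 64)
#   q_rot, k_rot = rope(q, k)       # each still (2, 8, 16, 64)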


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, expand it (a view, no copy), then fold it into the head axis
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
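

if __name__ == "__main__":
    # Quick smoke test, a sketch only: the head counts and sequence length below are
    # assumptions for illustration, not values taken from any particular model config.
    torch.manual_seed(0)
    batch, seq_len, head_dim = 2, 16, 64
    num_heads, num_kv_heads = 8, 2  # grouped-query attention: 4 query heads per kv head

    rope = LlamaRotaryEmbedding(dim=head_dim)
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_kv_heads, seq_len, head_dim)

    # Broadcast the key heads to match the query heads, then rotate both
    k = repeat_kv(k, num_heads // num_kv_heads)  # (2, 2, 16, 64) -> (2, 8, 16, 64)
    q_rot, k_rot = rope(q, k)
    print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) twice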