HV-Khurdula's picture
Update rope.py
f542ccb verified
raw
history blame
2 kB
# Ethically sourced from https://github.com/xjdr-alt/entropix
import torch
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Precompute the rotary-embedding (RoPE) cos/sin lookup table.

    For each position ``p`` in ``range(end)`` and each frequency index
    ``i`` in ``range(dim // 2)``, the angle is
    ``p * theta ** (-(2 * i) / dim)``.

    Args:
        dim: Rotary width. The table holds ``dim // 2`` frequencies.
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for signature compatibility; this
            implementation does not read it.
        dtype: Floating dtype used for the computation and the result.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)`` and dtype ``dtype``.
        ``[..., 0]`` holds the cosines and ``[..., 1]`` the sines of the
        angles.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
    positions = torch.arange(end, dtype=dtype).unsqueeze(1)
    angles = positions * inv_freq.unsqueeze(0)  # (end, dim // 2)
    # cos/sin directly rather than exp(1j * angles): same values, but no
    # complex intermediate, so half/bfloat16 dtypes work and keep their dtype.
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
    """Apply rotary position embeddings to the leading channels of ``x``.

    The first ``rd`` channels are rotated in adjacent pairs ``(2i, 2i + 1)``.
    Each pair is treated as ``(real, imag)`` and multiplied by
    ``cos + i*sin`` taken from ``freqs_cis``. The remaining channels are
    passed through unchanged.

    Args:
        x: Tensor of shape ``(B, H, T, D)``.
        freqs_cis: Table indexed by position along dim 0, expected to be
            3-D ``(max_pos, rot_half, 2)``. ``[..., 0]`` is used as cos and
            ``[..., 1]`` as sin.
        position_ids: Positions to look up. Accepted forms:
            - a ``(B, T)`` tensor;
            - a ``(1, T)`` tensor, broadcast over the batch;
            - a ``(T,)`` tensor, shared by all batch rows;
            - a single-element tensor, one position for every token;
            - a Python int or (nested) sequence, converted to a ``long``
              tensor on ``x.device`` and dispatched like the tensor forms.
        num_heads: Unused; the head count comes from ``x.shape``.
        rot_dim: Number of leading channels to rotate. ``None`` (or 0)
            means ``2 * freqs_cis.size(-2)``. The value is clamped to ``D``
            and must end up even for the pair view to succeed.
        interleave: Unused; the adjacent-pair layout is always applied.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    B, H, T, D = x.shape
    rot_half_from_freqs = freqs_cis.size(-2)
    # `or` makes rot_dim=0 fall back to the table width as well.
    rd = min(rot_dim or (rot_half_from_freqs * 2), D)
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    # Normalise non-tensor input so every form goes through one dispatch.
    if not torch.is_tensor(position_ids):
        position_ids = torch.tensor(position_ids, device=x.device, dtype=torch.long)

    # Gather cos/sin for these positions -> (B, T, rot_half_from_freqs, 2).
    if position_ids.dim() == 2 and position_ids.size(0) in (1, B):
        # Per-row positions; a leading batch dim of 1 is broadcast to B.
        freq = freqs_cis[position_ids].expand(B, -1, -1, -1)
    elif position_ids.dim() == 1 and position_ids.size(0) == T:
        freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
    else:
        # One position shared by every token; view(()) rejects >1 element.
        pid = position_ids.view(()).long()
        freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)

    rot_half = rd // 2
    cos = freq[..., 0][..., :rot_half].unsqueeze(1)  # (B, 1, T, rot_half)
    sin = freq[..., 1][..., :rot_half].unsqueeze(1)  # (B, 1, T, rot_half)

    # Split rotated channels into adjacent (real, imag) pairs.
    x_rot = x_rot.view(B, H, T, rot_half, 2)
    xr, xi = x_rot[..., 0], x_rot[..., 1]
    # Complex product (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos
    y = torch.stack((yr, yi), dim=-1).flatten(-2)
    # Arithmetic may be promoted to the table's dtype; cast back to x's dtype.
    return torch.cat([y.to(x.dtype), x_pass], dim=-1)