Update rope.py
Browse files
rope.py
CHANGED
|
@@ -18,43 +18,34 @@ def precompute_freqs_cis(
|
|
| 18 |
|
| 19 |
|
| 20 |
def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
|
| 21 |
-
"""
|
| 22 |
-
x: (B, num_heads, q_len, head_dim)
|
| 23 |
-
freqs_cis: (max_seq, rot_half, 2) # [..., cos/sin]
|
| 24 |
-
position_ids: (q_len,) or (B, q_len) or scalar
|
| 25 |
-
num_heads: int (unused here; kept for API compatibility)
|
| 26 |
-
rot_dim: optional; if None we use min(D, 2*rot_half)
|
| 27 |
-
"""
|
| 28 |
B, H, T, D = x.shape
|
| 29 |
-
rot_half_from_freqs = freqs_cis.size(-2)
|
| 30 |
rd = rot_dim or (rot_half_from_freqs * 2)
|
| 31 |
-
rd = min(rd, D)
|
| 32 |
|
| 33 |
-
x_rot, x_pass = x[..., :rd], x[..., rd:]
|
| 34 |
|
| 35 |
-
#
|
| 36 |
if torch.is_tensor(position_ids):
|
| 37 |
-
if position_ids.dim() == 2 and position_ids.size(0) == B:
|
| 38 |
-
freq = freqs_cis[position_ids]
|
| 39 |
-
elif position_ids.dim() == 1 and position_ids.size(0) == T:
|
| 40 |
freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
|
| 41 |
-
else:
|
| 42 |
pid = position_ids.view(()).long()
|
| 43 |
freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
|
| 44 |
else:
|
| 45 |
pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
|
| 46 |
freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
|
| 47 |
|
| 48 |
-
# Trim freqs to rd//2 if needed
|
| 49 |
rot_half = rd // 2
|
| 50 |
-
cos = freq[..., 0][..., :rot_half].unsqueeze(1)
|
| 51 |
-
sin = freq[..., 1][..., :rot_half].unsqueeze(1)
|
| 52 |
|
| 53 |
-
# Split real/imag and apply rotation
|
| 54 |
x_rot = x_rot.view(B, H, T, rot_half, 2)
|
| 55 |
-
xr, xi = x_rot[..., 0], x_rot[..., 1]
|
| 56 |
yr = xr * cos - xi * sin
|
| 57 |
yi = xr * sin + xi * cos
|
| 58 |
-
y = torch.stack((yr, yi), dim=-1).flatten(-2)
|
| 59 |
|
| 60 |
-
return torch.cat([y.to(x.dtype), x_pass], dim=-1)
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
B, H, T, D = x.shape
|
| 22 |
+
rot_half_from_freqs = freqs_cis.size(-2)
|
| 23 |
rd = rot_dim or (rot_half_from_freqs * 2)
|
| 24 |
+
rd = min(rd, D)
|
| 25 |
|
| 26 |
+
x_rot, x_pass = x[..., :rd], x[..., rd:]
|
| 27 |
|
| 28 |
+
# gather cos/sin for these positions
|
| 29 |
if torch.is_tensor(position_ids):
|
| 30 |
+
if position_ids.dim() == 2 and position_ids.size(0) == B:
|
| 31 |
+
freq = freqs_cis[position_ids] # (B,T,rot_half,2)
|
| 32 |
+
elif position_ids.dim() == 1 and position_ids.size(0) == T:
|
| 33 |
freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
|
| 34 |
+
else: # scalar
|
| 35 |
pid = position_ids.view(()).long()
|
| 36 |
freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
|
| 37 |
else:
|
| 38 |
pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
|
| 39 |
freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
|
| 40 |
|
|
|
|
| 41 |
rot_half = rd // 2
|
| 42 |
+
cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
|
| 43 |
+
sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
|
| 44 |
|
|
|
|
| 45 |
x_rot = x_rot.view(B, H, T, rot_half, 2)
|
| 46 |
+
xr, xi = x_rot[..., 0], x_rot[..., 1]
|
| 47 |
yr = xr * cos - xi * sin
|
| 48 |
yi = xr * sin + xi * cos
|
| 49 |
+
y = torch.stack((yr, yi), dim=-1).flatten(-2)
|
| 50 |
|
| 51 |
+
return torch.cat([y.to(x.dtype), x_pass], dim=-1)
|