Update rope.py
Browse files
fix: fallback to apply rotary_emb
rope.py
CHANGED
|
@@ -16,36 +16,37 @@ def precompute_freqs_cis(
|
|
| 16 |
freqs = torch.exp(1j * freqs)
|
| 17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
#
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
else:
|
| 38 |
-
|
| 39 |
-
freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
|
| 40 |
-
|
| 41 |
-
rot_half = rd // 2
|
| 42 |
-
cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
|
| 43 |
-
sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
yi = xr * sin + xi * cos
|
| 49 |
-
y = torch.stack((yr, yi), dim=-1).flatten(-2)
|
| 50 |
|
| 51 |
-
return torch.cat([y.to(x.dtype), x_pass], dim=-1)
|
|
|
|
| 16 |
freqs = torch.exp(1j * freqs)
|
| 17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
| 18 |
|
| 19 |
+
def apply_rotary_emb(x_q: torch.Tensor,
|
| 20 |
+
x_k: torch.Tensor,
|
| 21 |
+
cos: torch.Tensor,
|
| 22 |
+
sin: torch.Tensor,
|
| 23 |
+
position_ids: torch.Tensor | None = None):
|
| 24 |
+
"""
|
| 25 |
+
Applies RoPE to q, k.
|
| 26 |
+
Keeps shapes: (B, H, T, D) or (B, T, H, D). Cos/sin are cast to x.dtype.
|
| 27 |
+
"""
|
| 28 |
+
# Align dtypes to avoid fp16/bf16 → fp32 mismatches
|
| 29 |
+
cos = cos.to(x_q.dtype)
|
| 30 |
+
sin = sin.to(x_q.dtype)
|
| 31 |
+
|
| 32 |
+
def _rotate_half(x):
|
| 33 |
+
d = x.shape[-1] // 2
|
| 34 |
+
x1, x2 = x[..., :d], x[..., d:]
|
| 35 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 36 |
+
|
| 37 |
+
if position_ids is not None:
|
| 38 |
+
# (B, T) → index cos/sin per token
|
| 39 |
+
# cos/sin expected as (L, D) or broadcastable to x
|
| 40 |
+
cos_ = cos.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
| 41 |
+
sin_ = sin.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
| 42 |
+
# reshape to broadcast over heads
|
| 43 |
+
while cos_.dim() < x_q.dim():
|
| 44 |
+
cos_ = cos_.unsqueeze(1)
|
| 45 |
+
sin_ = sin_.unsqueeze(1)
|
| 46 |
else:
|
| 47 |
+
cos_, sin_ = cos, sin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
q = (x_q * cos_) + (_rotate_half(x_q) * sin_)
|
| 50 |
+
k = (x_k * cos_) + (_rotate_half(x_k) * sin_)
|
| 51 |
+
return q, k
|
|
|
|
|
|
|
| 52 |
|
|
|