Update rope.py
Browse filesfix: apply rotatory embed.
rope.py
CHANGED
@@ -16,37 +16,54 @@ def precompute_freqs_cis(
|
|
16 |
freqs = torch.exp(1j * freqs)
|
17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
"""
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
27 |
"""
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
# cos/sin expected as (L, D) or broadcastable to x
|
40 |
-
cos_ = cos.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
41 |
-
sin_ = sin.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
42 |
-
# reshape to broadcast over heads
|
43 |
-
while cos_.dim() < x_q.dim():
|
44 |
-
cos_ = cos_.unsqueeze(1)
|
45 |
-
sin_ = sin_.unsqueeze(1)
|
46 |
else:
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
q = (x_q * cos_) + (_rotate_half(x_q) * sin_)
|
50 |
-
k = (x_k * cos_) + (_rotate_half(x_k) * sin_)
|
51 |
-
return q, k
|
52 |
|
|
|
16 |
freqs = torch.exp(1j * freqs)
|
17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
18 |
|
19 |
+
# rope.py
|
20 |
+
import torch
|
21 |
+
|
22 |
+
def apply_rotary_emb(
|
23 |
+
x: torch.Tensor,
|
24 |
+
freqs_cis: torch.Tensor,
|
25 |
+
position_ids: torch.Tensor,
|
26 |
+
num_heads: int,
|
27 |
+
rot_dim: int = 32,
|
28 |
+
interleave: bool = False,
|
29 |
+
) -> torch.Tensor:
|
30 |
"""
|
31 |
+
RoPE as used in the original moondream2 text stack:
|
32 |
+
x: (B, H, T, D)
|
33 |
+
freqs_cis: (max_T, rot_dim//2, 2) where [...,0]=cos, [...,1]=sin
|
34 |
+
position_ids: (T,) or (B,T)
|
35 |
+
returns x with first rot_dim dims rotated.
|
36 |
"""
|
37 |
+
assert rot_dim == freqs_cis.shape[-2] * 2
|
38 |
+
assert num_heads == x.shape[1]
|
39 |
+
|
40 |
+
B, H, T, D = x.shape
|
41 |
+
rd = min(rot_dim, D)
|
42 |
+
x_rot, x_pass = x[..., :rd], x[..., rd:]
|
43 |
+
|
44 |
+
# split real/imag parts depending on layout
|
45 |
+
if interleave:
|
46 |
+
xr = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
|
47 |
+
xi = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
else:
|
49 |
+
d = x_rot.shape[-1] // 2
|
50 |
+
xr, xi = x_rot[..., :d], x_rot[..., d:]
|
51 |
+
|
52 |
+
# gather cos/sin for these positions
|
53 |
+
if position_ids.dim() == 2 and position_ids.size(0) == B:
|
54 |
+
freq = freqs_cis[position_ids] # (B, T, rd//2, 2)
|
55 |
+
else: # (T,) or scalar
|
56 |
+
freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
|
57 |
+
|
58 |
+
rot_half = rd // 2
|
59 |
+
cos = freq[..., 0][..., :rot_half].unsqueeze(1).to(x.dtype) # (B,1,T,rot_half)
|
60 |
+
sin = freq[..., 1][..., :rot_half].unsqueeze(1).to(x.dtype)
|
61 |
+
|
62 |
+
# complex multiply
|
63 |
+
yr = xr * cos - xi * sin
|
64 |
+
yi = xr * sin + xi * cos
|
65 |
+
y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
|
66 |
+
|
67 |
+
return torch.cat([y, x_pass], dim=-1)
|
68 |
|
|
|
|
|
|
|
69 |
|