HV-Khurdula committed on
Commit
09597f3
·
verified ·
1 Parent(s): 4fe8766

Update rope.py

Browse files

fix: apply rotary embedding.

Files changed (1) hide show
  1. rope.py +46 -29
rope.py CHANGED
@@ -16,37 +16,54 @@ def precompute_freqs_cis(
16
  freqs = torch.exp(1j * freqs)
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
- def apply_rotary_emb(x_q: torch.Tensor,
20
- x_k: torch.Tensor,
21
- cos: torch.Tensor,
22
- sin: torch.Tensor,
23
- position_ids: torch.Tensor | None = None):
 
 
 
 
 
 
24
  """
25
- Applies RoPE to q, k.
26
- Keeps shapes: (B, H, T, D) or (B, T, H, D). Cos/sin are cast to x.dtype.
 
 
 
27
  """
28
- # Align dtypes to avoid fp16/bf16 → fp32 mismatches
29
- cos = cos.to(x_q.dtype)
30
- sin = sin.to(x_q.dtype)
31
-
32
- def _rotate_half(x):
33
- d = x.shape[-1] // 2
34
- x1, x2 = x[..., :d], x[..., d:]
35
- return torch.cat((-x2, x1), dim=-1)
36
-
37
- if position_ids is not None:
38
- # (B, T) → index cos/sin per token
39
- # cos/sin expected as (L, D) or broadcastable to x
40
- cos_ = cos.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
41
- sin_ = sin.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
42
- # reshape to broadcast over heads
43
- while cos_.dim() < x_q.dim():
44
- cos_ = cos_.unsqueeze(1)
45
- sin_ = sin_.unsqueeze(1)
46
  else:
47
- cos_, sin_ = cos, sin
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- q = (x_q * cos_) + (_rotate_half(x_q) * sin_)
50
- k = (x_k * cos_) + (_rotate_half(x_k) * sin_)
51
- return q, k
52
 
 
16
  freqs = torch.exp(1j * freqs)
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
+ # rope.py
20
+ import torch
21
+
22
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings (RoPE) to the leading channels of ``x``.

    Args:
        x: Tensor of shape (B, H, T, D).
        freqs_cis: Table of shape (max_T, rot_dim // 2, 2) where
            ``[..., 0]`` holds cos and ``[..., 1]`` holds sin.
        position_ids: Integer positions used to index ``freqs_cis``. Accepted
            shapes are scalar, (T,), (1, T) or (B, T); anything with a leading
            size of 1 is broadcast over the batch.
        num_heads: Expected size of ``x.shape[1]`` (sanity check only).
        rot_dim: Number of leading channels to rotate; must equal
            ``2 * freqs_cis.shape[-2]``.
        interleave: If True, the (real, imag) pairs are adjacent channels
            ``(0, 1), (2, 3), ...``; otherwise the real parts are the first
            half of the rotated channels and the imaginary parts the second.

    Returns:
        A tensor with the same shape and dtype as ``x``. The first
        ``min(rot_dim, D)`` channels are rotated, the rest are passed through
        unchanged.

    Note:
        The rotated channels are always written back in interleaved
        ``(real, imag)`` order, even when ``interleave`` is False. Queries and
        keys that both go through this function receive the same channel
        permutation, so their dot products are unaffected.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2, "rot_dim must be 2 * freqs_cis.shape[-2]"
    assert num_heads == x.shape[1], "num_heads must match x.shape[1]"

    # Unpacking also enforces that x is rank 4.
    _, _, _, head_dim = x.shape
    rd = min(rot_dim, head_dim)
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    # Split real/imag parts depending on layout.
    if interleave:
        # Reshape/upcast once and take both components from the same view.
        pairs = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)
        xr, xi = pairs[..., 0], pairs[..., 1]
    else:
        d = x_rot.shape[-1] // 2
        xr, xi = x_rot[..., :d], x_rot[..., d:]

    # Gather cos/sin for these positions and normalise to (b, T, rd//2, 2)
    # with b in {1, B}; broadcasting handles the batch and head axes, so no
    # explicit expand is needed and a (1, T) position_ids works for any B.
    freq = freqs_cis[position_ids]
    while freq.dim() < 4:
        freq = freq.unsqueeze(0)

    rot_half = rd // 2
    cos = freq[..., 0][..., :rot_half].unsqueeze(1).to(x.dtype)  # (b, 1, T, rot_half)
    sin = freq[..., 1][..., :rot_half].unsqueeze(1).to(x.dtype)

    # Complex multiply: (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos
    # The interleave path computes in float32; cast back so the concat below
    # neither promotes the whole output nor fails on mixed dtypes.
    y = torch.stack((yr, yi), dim=-1).flatten(-2).to(x.dtype)  # (B, H, T, rd)

    return torch.cat([y, x_pass], dim=-1)
68
 
 
 
 
69