HV-Khurdula committed on
Commit
4fe8766
·
verified ·
1 Parent(s): d53b116

Update rope.py

Browse files

fix: fallback to apply rotary_emb

Files changed (1) hide show
  1. rope.py +31 -30
rope.py CHANGED
@@ -16,36 +16,37 @@ def precompute_freqs_cis(
16
  freqs = torch.exp(1j * freqs)
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
-
20
- def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
21
- B, H, T, D = x.shape
22
- rot_half_from_freqs = freqs_cis.size(-2)
23
- rd = rot_dim or (rot_half_from_freqs * 2)
24
- rd = min(rd, D)
25
-
26
- x_rot, x_pass = x[..., :rd], x[..., rd:]
27
-
28
- # gather cos/sin for these positions
29
- if torch.is_tensor(position_ids):
30
- if position_ids.dim() == 2 and position_ids.size(0) == B:
31
- freq = freqs_cis[position_ids] # (B,T,rot_half,2)
32
- elif position_ids.dim() == 1 and position_ids.size(0) == T:
33
- freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
34
- else: # scalar
35
- pid = position_ids.view(()).long()
36
- freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
 
 
 
 
 
 
 
 
 
37
  else:
38
- pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
39
- freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
40
-
41
- rot_half = rd // 2
42
- cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
43
- sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
44
 
45
- x_rot = x_rot.view(B, H, T, rot_half, 2)
46
- xr, xi = x_rot[..., 0], x_rot[..., 1]
47
- yr = xr * cos - xi * sin
48
- yi = xr * sin + xi * cos
49
- y = torch.stack((yr, yi), dim=-1).flatten(-2)
50
 
51
- return torch.cat([y.to(x.dtype), x_pass], dim=-1)
 
16
  freqs = torch.exp(1j * freqs)
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
+ def apply_rotary_emb(x_q: torch.Tensor,
20
+ x_k: torch.Tensor,
21
+ cos: torch.Tensor,
22
+ sin: torch.Tensor,
23
+ position_ids: torch.Tensor | None = None):
24
+ """
25
+ Applies RoPE to q, k.
26
+ Keeps shapes: (B, H, T, D) or (B, T, H, D). Cos/sin are cast to x.dtype.
27
+ """
28
+ # Align dtypes to avoid fp16/bf16 fp32 mismatches
29
+ cos = cos.to(x_q.dtype)
30
+ sin = sin.to(x_q.dtype)
31
+
32
+ def _rotate_half(x):
33
+ d = x.shape[-1] // 2
34
+ x1, x2 = x[..., :d], x[..., d:]
35
+ return torch.cat((-x2, x1), dim=-1)
36
+
37
+ if position_ids is not None:
38
+ # (B, T) → index cos/sin per token
39
+ # cos/sin expected as (L, D) or broadcastable to x
40
+ cos_ = cos.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
41
+ sin_ = sin.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
42
+ # reshape to broadcast over heads
43
+ while cos_.dim() < x_q.dim():
44
+ cos_ = cos_.unsqueeze(1)
45
+ sin_ = sin_.unsqueeze(1)
46
  else:
47
+ cos_, sin_ = cos, sin
 
 
 
 
 
48
 
49
+ q = (x_q * cos_) + (_rotate_half(x_q) * sin_)
50
+ k = (x_k * cos_) + (_rotate_half(x_k) * sin_)
51
+ return q, k
 
 
52