HV-Khurdula commited on
Commit
f542ccb
·
verified ·
1 Parent(s): 4d9e33f

Update rope.py

Browse files
Files changed (1) hide show
  1. rope.py +13 -22
rope.py CHANGED
@@ -18,43 +18,34 @@ def precompute_freqs_cis(
18
 
19
 
20
  def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
21
- """
22
- x: (B, num_heads, q_len, head_dim)
23
- freqs_cis: (max_seq, rot_half, 2) # [..., cos/sin]
24
- position_ids: (q_len,) or (B, q_len) or scalar
25
- num_heads: int (unused here; kept for API compatibility)
26
- rot_dim: optional; if None we use min(D, 2*rot_half)
27
- """
28
  B, H, T, D = x.shape
29
- rot_half_from_freqs = freqs_cis.size(-2) # available rotary half
30
  rd = rot_dim or (rot_half_from_freqs * 2)
31
- rd = min(rd, D) # don't exceed head_dim
32
 
33
- x_rot, x_pass = x[..., :rd], x[..., rd:] # (B,H,T,rd), (B,H,T,D-rd)
34
 
35
- # Gather cos/sin for each position; result (B,T,rot_half_from_freqs)
36
  if torch.is_tensor(position_ids):
37
- if position_ids.dim() == 2 and position_ids.size(0) == B: # (B,T)
38
- freq = freqs_cis[position_ids] # (B,T,rot_half,2)
39
- elif position_ids.dim() == 1 and position_ids.size(0) == T: # (T,)
40
  freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
41
- else: # scalar
42
  pid = position_ids.view(()).long()
43
  freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
44
  else:
45
  pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
46
  freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
47
 
48
- # Trim freqs to rd//2 if needed
49
  rot_half = rd // 2
50
- cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
51
- sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
52
 
53
- # Split real/imag and apply rotation
54
  x_rot = x_rot.view(B, H, T, rot_half, 2)
55
- xr, xi = x_rot[..., 0], x_rot[..., 1] # (B,H,T,rot_half)
56
  yr = xr * cos - xi * sin
57
  yi = xr * sin + xi * cos
58
- y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
59
 
60
- return torch.cat([y.to(x.dtype), x_pass], dim=-1)
 
18
 
19
 
20
  def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
 
 
 
 
 
 
 
21
  B, H, T, D = x.shape
22
+ rot_half_from_freqs = freqs_cis.size(-2)
23
  rd = rot_dim or (rot_half_from_freqs * 2)
24
+ rd = min(rd, D)
25
 
26
+ x_rot, x_pass = x[..., :rd], x[..., rd:]
27
 
28
+ # gather cos/sin for these positions
29
  if torch.is_tensor(position_ids):
30
+ if position_ids.dim() == 2 and position_ids.size(0) == B:
31
+ freq = freqs_cis[position_ids] # (B,T,rot_half,2)
32
+ elif position_ids.dim() == 1 and position_ids.size(0) == T:
33
  freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
34
+ else: # scalar
35
  pid = position_ids.view(()).long()
36
  freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
37
  else:
38
  pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
39
  freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
40
 
 
41
  rot_half = rd // 2
42
+ cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
43
+ sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
44
 
 
45
  x_rot = x_rot.view(B, H, T, rot_half, 2)
46
+ xr, xi = x_rot[..., 0], x_rot[..., 1]
47
  yr = xr * cos - xi * sin
48
  yi = xr * sin + xi * cos
49
+ y = torch.stack((yr, yi), dim=-1).flatten(-2)
50
 
51
+ return torch.cat([y.to(x.dtype), x_pass], dim=-1)