eunhwanpark-motiftech committed on
Commit a162402 · verified · 1 Parent(s): 6187f77

Update modeling_motif.py


remove apply_rotary_pos_emb fused rope

Files changed (1): modeling_motif.py (+6 −33)
modeling_motif.py CHANGED

@@ -263,7 +263,7 @@ def rotate_half(x):
     return rotated_tensor
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=False):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     Applies rotary position embeddings to the input tensors.
 
@@ -274,9 +274,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fus
         sin (torch.Tensor): Sine values for rotary embedding.
         unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
             Defaults to 1.
-        fused_rope (bool, optional): If True, applies fused rotary embeddings using
-            `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
-            Defaults to False.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
@@ -288,31 +285,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fus
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     '''
-    if not fused_rope:
-        device = q.device
-        return map(
+    device = q.device
+    return map(
         lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
         (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
-    else:
-        # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
-        cos = cos[position_ids]
-        sin = sin[position_ids]
-
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-
-        # Expand 'batch' dim
-        cos = cos.expand(q.shape[0], *cos.shape[1:])
-        sin = sin.expand(q.shape[0], *sin.shape[1:])
-
-        q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
-        k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
-
-        # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
-        q_embed = q_embed.transpose(1, 2)
-        k_embed = k_embed.transpose(1, 2)
-
-        return q_embed, k_embed
 
 
 class MotifMLP(nn.Module):
@@ -461,8 +437,7 @@ class MotifAttention(nn.Module):
                                                          key_states,
                                                          cos,
                                                          sin,
-                                                         position_ids=position_ids,
-                                                         fused_rope=self.config.fused_rope)
+                                                         position_ids=position_ids)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
@@ -609,8 +584,7 @@ class MotifFlashAttention2(MotifAttention):
                                                          key_states,
                                                          cos,
                                                          sin,
-                                                         position_ids=position_ids,
-                                                         fused_rope=False)
+                                                         position_ids=position_ids)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
@@ -758,8 +732,7 @@ class MotifSdpaAttention(MotifAttention):
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
-                                                        sin,
-                                                        fused_rope=self.config.fused_rope)
+                                                        sin)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models