leejunhyeok commited on
Commit
9b40539
·
verified ·
1 Parent(s): 607612f

Update modeling_motif.py

Browse files
Files changed (1) hide show
  1. modeling_motif.py +4 -26
modeling_motif.py CHANGED
@@ -261,35 +261,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
261
  sin (torch.Tensor): Sine values for rotary embedding.
262
  unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
263
  Defaults to 1.
264
- fused_rope (bool, optional): If True, applies fused rotary embeddings using
265
- `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
266
- Defaults to False.
267
  Returns:
268
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
269
  """
270
- '''
271
- # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
272
- cos = cos.unsqueeze(unsqueeze_dim)
273
- sin = sin.unsqueeze(unsqueeze_dim)
274
- q_embed = (q * cos) + (rotate_half(q) * sin)
275
- k_embed = (k * cos) + (rotate_half(k) * sin)
276
- '''
277
-
278
- q = q.transpose(1, 2)
279
- k = k.transpose(1, 2)
280
-
281
- # Expand 'batch' dim
282
- cos = cos.expand(q.shape[0], *cos.shape[1:])
283
- sin = sin.expand(q.shape[0], *sin.shape[1:])
284
-
285
- q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
286
- k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
287
-
288
- # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
289
- q_embed = q_embed.transpose(1, 2)
290
- k_embed = k_embed.transpose(1, 2)
291
-
292
- return q_embed, k_embed
293
 
294
 
295
  class MotifMLP(nn.Module):
 
261
  sin (torch.Tensor): Sine values for rotary embedding.
262
  unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
263
  Defaults to 1.
 
 
 
264
  Returns:
265
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
266
  """
267
+ device = q.device
268
+ return map(
269
+ lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
270
+ (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
  class MotifMLP(nn.Module):