Update modeling_motif.py
remove apply_rotary_pos_emb fused rope

modeling_motif.py  CHANGED  (+6 -33)
@@ -263,7 +263,7 @@ def rotate_half(x):
     return rotated_tensor
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=False):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     Applies rotary position embeddings to the input tensors.
 
@@ -274,9 +274,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=False):
         sin (torch.Tensor): Sine values for rotary embedding.
         unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
             Defaults to 1.
-        fused_rope (bool, optional): If True, applies fused rotary embeddings using
-            `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
-            Defaults to False.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
@@ -288,31 +285,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=False):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
     '''
-    if not fused_rope:
-        device = q.device
-        return map(
+    device = q.device
+    return map(
         lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
         (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
-    else:
-        # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
-        cos = cos[position_ids]
-        sin = sin[position_ids]
-
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-
-        # Expand 'batch' dim
-        cos = cos.expand(q.shape[0], *cos.shape[1:])
-        sin = sin.expand(q.shape[0], *sin.shape[1:])
-
-        q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
-        k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
-
-        # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
-        q_embed = q_embed.transpose(1, 2)
-        k_embed = k_embed.transpose(1, 2)
-
-        return q_embed, k_embed
 
 
 class MotifMLP(nn.Module):
@@ -461,8 +437,7 @@ class MotifAttention(nn.Module):
                                                         key_states,
                                                         cos,
                                                         sin,
-                                                        position_ids=position_ids,
-                                                        fused_rope=self.config.fused_rope)
+                                                        position_ids=position_ids)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
@@ -609,8 +584,7 @@ class MotifFlashAttention2(MotifAttention):
                                                         key_states,
                                                         cos,
                                                         sin,
-                                                        position_ids=position_ids,
-                                                        fused_rope=False)
+                                                        position_ids=position_ids)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
@@ -758,8 +732,7 @@ class MotifSdpaAttention(MotifAttention):
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
-                                                        sin,
-                                                        fused_rope=self.config.fused_rope)
+                                                        sin)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
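For context, here is a minimal runnable sketch of the code path this commit keeps: the manual (non-fused) RoPE application. Only the post-commit body of apply_rotary_pos_emb comes from the diff; the rotate_half body, the (seq_len, head_dim) cos/sin cache layout, and all shapes below are assumptions in the common Llama-style convention this file appears to follow.

import torch

def rotate_half(x):
    # Assumed Llama-style body; the diff only shows this function's name.
    # Split the head dim in half and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Post-commit body from the diff: index the cached cos/sin tables by
    # position_ids, insert a broadcast dim for the heads, and rotate q and k.
    device = q.device
    return map(
        lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
        (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))

# Hypothetical shapes: batch=2, heads=4, seq=8, head_dim=16.
B, NH, S, D = 2, 4, 8, 16
q = torch.randn(B, NH, S, D)
k = torch.randn(B, NH, S, D)

# Assumed cache construction with standard RoPE frequencies, giving
# cos/sin tables of shape (seq_len, head_dim).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
freqs = torch.outer(torch.arange(S).float(), inv_freq)  # (S, D/2)
emb = torch.cat((freqs, freqs), dim=-1)                 # (S, D)
cos, sin = emb.cos(), emb.sin()
position_ids = torch.arange(S).unsqueeze(0).expand(B, S)

# map() returns a lazy iterator; the attention call sites immediately
# unpack it into the two rotated tensors.
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 8, 16]) twice

Note that cos[position_ids] turns the (S, D) table into a (B, S, D) tensor, and unsqueeze_dim=1 inserts the head axis so it broadcasts against the (B, NH, S, D) query and key states without any transposes, which is what made the removed fused branch's layout shuffling unnecessary here.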