Update modeling_motif.py
Browse files- modeling_motif.py +6 -5
modeling_motif.py
CHANGED
@@ -455,11 +455,12 @@ class MotifAttention(nn.Module):
|
|
455 |
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
|
456 |
if use_cache else position_embeddings)
|
457 |
|
458 |
-
query_states, key_states = apply_rotary_pos_emb(
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
|
|
463 |
|
464 |
if past_key_value is not None:
|
465 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
|
|
455 |
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
|
456 |
if use_cache else position_embeddings)
|
457 |
|
458 |
+
query_states, key_states = apply_rotary_pos_emb(
|
459 |
+
query_states,
|
460 |
+
key_states,
|
461 |
+
cos,
|
462 |
+
sin
|
463 |
+
)
|
464 |
|
465 |
if past_key_value is not None:
|
466 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|