Update modeling_motif.py
Browse files- modeling_motif.py +4 -26
modeling_motif.py
CHANGED
@@ -261,35 +261,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
261 |
sin (torch.Tensor): Sine values for rotary embedding.
|
262 |
unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
|
263 |
Defaults to 1.
|
264 |
-
fused_rope (bool, optional): If True, applies fused rotary embeddings using
|
265 |
-
`moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
|
266 |
-
Defaults to False.
|
267 |
Returns:
|
268 |
Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
|
269 |
"""
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
275 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
276 |
-
'''
|
277 |
-
|
278 |
-
q = q.transpose(1, 2)
|
279 |
-
k = k.transpose(1, 2)
|
280 |
-
|
281 |
-
# Expand 'batch' dim
|
282 |
-
cos = cos.expand(q.shape[0], *cos.shape[1:])
|
283 |
-
sin = sin.expand(q.shape[0], *sin.shape[1:])
|
284 |
-
|
285 |
-
q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
|
286 |
-
k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
|
287 |
-
|
288 |
-
# (B, S, NH, D_KV) -> (B, NH, S, D_KV)
|
289 |
-
q_embed = q_embed.transpose(1, 2)
|
290 |
-
k_embed = k_embed.transpose(1, 2)
|
291 |
-
|
292 |
-
return q_embed, k_embed
|
293 |
|
294 |
|
295 |
class MotifMLP(nn.Module):
|
|
|
261 |
sin (torch.Tensor): Sine values for rotary embedding.
|
262 |
unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
|
263 |
Defaults to 1.
|
|
|
|
|
|
|
264 |
Returns:
|
265 |
Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
|
266 |
"""
|
267 |
+
device = q.device
|
268 |
+
return map(
|
269 |
+
lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
|
270 |
+
(rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
|
273 |
class MotifMLP(nn.Module):
|