Update modeling_motif.py

modeling_motif.py  +3  -121  CHANGED
@@ -1040,13 +1040,12 @@ class MotifModel(MotifPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         # NOTE: For multi-token models, the last decoder layers (one for each token index)
         # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
 
-        num_hidden_layers = config.num_hidden_layers
+        num_hidden_layers = config.num_hidden_layers
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
         self._attn_implementation = config._attn_implementation
         RMSNorm = MorehRMSNorm
@@ -1338,16 +1337,8 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = MotifModel(config)
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
-        if self.multi_token_heads is None:
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        else:
-            self.tokenwise_last_layers = nn.ModuleList(
-                [MotifDecoderLayer(config, config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
-            self.tokenwise_lm_heads = nn.ModuleList(
-                [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
-        self.should_skip_separate_backward_pass = self.multi_token_heads is not None
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1374,101 +1365,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model
 
-    def multi_token_forward_backward(self,
-                                     hidden_states: torch.FloatTensor,
-                                     outputs: MotifModelOutputWithPast,
-                                     labels: torch.LongTensor,
-                                     position_ids: Optional[torch.LongTensor],
-                                     output_attentions: Optional[bool],
-                                     use_cache: Optional[bool],
-                                     cache_position: Optional[torch.LongTensor],
-                                     return_dict: Optional[bool],
-                                     num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
-        """
-        This implements the main forward-backward procedure for multi-token model training proposed in
-        the paper https://arxiv.org/abs/2404.19737.
-        Essentially,
-        - The multi-token model tries to predict n (instead of 1) tokens at a time.
-        - Applying this only during training and using first-token prediction during inference is still helpful.
-        - The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
-          (1) last attention layer and (2) lm head.
-        - The change in loss: sum of cross-entropy losses corresponding to each token index.
-        - Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
-        """
-        if not return_dict:
-            raise NotImplementedError("return_dict must be True for multi-token training")
-
-        past_key_values = outputs.past_key_values
-        causal_mask = outputs.causal_mask
-        position_embeddings = outputs.position_embeddings
-
-        if labels is not None:
-            labels = labels.to(hidden_states.device)
-
-        def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
-            ## Model forward
-            layer = self.tokenwise_last_layers[token_idx]
-            lm_head = self.tokenwise_lm_heads[token_idx]
-
-            layer_outputs = layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,  # TODO: update past_key_values?
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-            last_hidden_states = layer_outputs[0]
-            if num_logits_to_keep > 0:
-                assert labels is None
-                last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
-            tokenwise_logits = lm_head(last_hidden_states)
-
-            if labels is None:
-                return {
-                    "loss": None,
-                    "logits": tokenwise_logits,
-                }
-
-            ## Compute loss
-            shift_n = token_idx + 1
-            shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
-            shift_labels = labels[..., shift_n:].contiguous()
-
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-
-            tokenwise_loss = loss_fct(shift_logits, shift_labels)
-
-            return {
-                "loss": tokenwise_loss,
-                "logits": tokenwise_logits,
-            }
-
-        head_fns = [
-            lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
-            for token_idx in range(self.multi_token_heads)
-        ]
-        loss, logits = multi_head_forward_backward(hidden_states,
-                                                   head_fns,
-                                                   return_keys=("loss", "logits"),
-                                                   return_only_first_head=True)
-
-        if not return_dict:
-            output = (logits, ) + outputs[1:]
-            return (loss, ) + output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
+
     @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1524,8 +1421,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs_include_causal_mask = self.multi_token_heads is not None
-        outputs_include_position_embeddings = self.multi_token_heads is not None
         outputs: MotifModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1537,23 +1432,10 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            outputs_include_causal_mask=outputs_include_causal_mask,
-            outputs_include_position_embeddings=outputs_include_position_embeddings,
         )
 
         hidden_states = outputs[0]
 
-        if self.multi_token_heads is not None:
-            return self.multi_token_forward_backward(hidden_states,
-                                                     outputs,
-                                                     labels,
-                                                     position_ids,
-                                                     output_attentions,
-                                                     use_cache,
-                                                     cache_position,
-                                                     return_dict,
-                                                     num_logits_to_keep=num_logits_to_keep)
-
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         hidden_states = hidden_states
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])