leejunhyeok committed
Commit d56ef75 · verified · 1 Parent(s): 7d2479a

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +3 -121
modeling_motif.py CHANGED
@@ -1040,13 +1040,12 @@ class MotifModel(MotifPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         # NOTE: For multi-token models, the last decoder layers (one for each token index)
         # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
 
-        num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
+        num_hidden_layers = config.num_hidden_layers
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
         self._attn_implementation = config._attn_implementation
         RMSNorm = MorehRMSNorm
@@ -1338,16 +1337,8 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = MotifModel(config)
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
-        if self.multi_token_heads is None:
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        else:
-            self.tokenwise_last_layers = nn.ModuleList(
-                [MotifDecoderLayer(config, config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
-            self.tokenwise_lm_heads = nn.ModuleList(
-                [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
-        self.should_skip_separate_backward_pass = self.multi_token_heads is not None
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1374,101 +1365,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model
 
-    def multi_token_forward_backward(self,
-                                     hidden_states: torch.FloatTensor,
-                                     outputs: MotifModelOutputWithPast,
-                                     labels: torch.LongTensor,
-                                     position_ids: Optional[torch.LongTensor],
-                                     output_attentions: Optional[bool],
-                                     use_cache: Optional[bool],
-                                     cache_position: Optional[torch.LongTensor],
-                                     return_dict: Optional[bool],
-                                     num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
-        """
-        This implements the main forward-backward procedure for multi-token model training proposed in
-        the paper https://arxiv.org/abs/2404.19737.
-        Essentially,
-        - The multi-token model tries to predict n (instead of 1) tokens at a time.
-        - Applying this only during training and using first-token prediction during inference is still helpful.
-        - The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
-          (1) last attention layer and (2) lm head.
-        - The change in loss: sum of cross-entropy losses corresponding to each token index.
-        - Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
-        """
-        if not return_dict:
-            raise NotImplementedError("return_dict must be True for multi-token training")
-
-        past_key_values = outputs.past_key_values
-        causal_mask = outputs.causal_mask
-        position_embeddings = outputs.position_embeddings
-
-        if labels is not None:
-            labels = labels.to(hidden_states.device)
-
-        def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
-            ## Model forward
-            layer = self.tokenwise_last_layers[token_idx]
-            lm_head = self.tokenwise_lm_heads[token_idx]
-
-            layer_outputs = layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,  # TODO: update past_key_values?
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-            last_hidden_states = layer_outputs[0]
-            if num_logits_to_keep > 0:
-                assert labels is None
-                last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
-            tokenwise_logits = lm_head(last_hidden_states)
-
-            if labels is None:
-                return {
-                    "loss": None,
-                    "logits": tokenwise_logits,
-                }
-
-            ## Compute loss
-            shift_n = token_idx + 1
-            shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
-            shift_labels = labels[..., shift_n:].contiguous()
-
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-
-            tokenwise_loss = loss_fct(shift_logits, shift_labels)
-
-            return {
-                "loss": tokenwise_loss,
-                "logits": tokenwise_logits,
-            }
-
-        head_fns = [
-            lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
-            for token_idx in range(self.multi_token_heads)
-        ]
-        loss, logits = multi_head_forward_backward(hidden_states,
-                                                   head_fns,
-                                                   return_keys=("loss", "logits"),
-                                                   return_only_first_head=True)
-
-        if not return_dict:
-            output = (logits, ) + outputs[1:]
-            return (loss, ) + output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
+
     @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1524,8 +1421,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs_include_causal_mask = self.multi_token_heads is not None
-        outputs_include_position_embeddings = self.multi_token_heads is not None
         outputs: MotifModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1537,23 +1432,10 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            outputs_include_causal_mask=outputs_include_causal_mask,
-            outputs_include_position_embeddings=outputs_include_position_embeddings,
         )
 
         hidden_states = outputs[0]
 
-        if self.multi_token_heads is not None:
-            return self.multi_token_forward_backward(hidden_states,
-                                                     outputs,
-                                                     labels,
-                                                     position_ids,
-                                                     output_attentions,
-                                                     use_cache,
-                                                     cache_position,
-                                                     return_dict,
-                                                     num_logits_to_keep=num_logits_to_keep)
-
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         hidden_states = hidden_states
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
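
For context, the path this commit removes is the multi-token prediction training procedure described in the deleted docstring (https://arxiv.org/abs/2404.19737): each token index gets its own final decoder layer and lm head, and the training loss is the sum of shifted cross-entropy losses over those heads. The snippet below is only a minimal sketch of that loss in plain PyTorch, not the repository's API: every name in it (multi_token_loss, trunk_hidden, heads, lm_heads, n_heads) is illustrative, and it deliberately skips the memory-efficient custom forward-backward (`multi_head_forward_backward`) that the removed code dispatched the heads through.

# Hedged sketch of the n-token prediction loss removed by this commit.
# Assumptions: plain PyTorch modules stand in for the per-head decoder layers;
# the custom forward/backward machinery of the original code is omitted.
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

def multi_token_loss(trunk_hidden, labels, heads, lm_heads):
    """Sum of shifted cross-entropy losses, one per token index.

    trunk_hidden: (batch, seq, hidden) output of the shared decoder trunk
    labels:       (batch, seq) target token ids
    heads:        per-token-index final layers (illustrative stand-ins)
    lm_heads:     per-token-index linear projections to the vocabulary
    """
    loss_fct = CrossEntropyLoss()
    total_loss = 0.0
    for token_idx, (head, lm_head) in enumerate(zip(heads, lm_heads)):
        hidden = head(trunk_hidden)          # head-specific last layer
        logits = lm_head(hidden)             # (batch, seq, vocab)
        shift_n = token_idx + 1              # head k predicts the token k+1 steps ahead
        shift_logits = logits[..., :-shift_n, :].contiguous()
        shift_labels = labels[..., shift_n:].contiguous()
        total_loss = total_loss + loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return total_loss

# Toy usage with randomly initialized stand-in modules.
batch, seq, hidden_size, vocab, n_heads = 2, 16, 32, 100, 4
trunk_hidden = torch.randn(batch, seq, hidden_size)
labels = torch.randint(0, vocab, (batch, seq))
heads = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(n_heads))
lm_heads = nn.ModuleList(nn.Linear(hidden_size, vocab, bias=False) for _ in range(n_heads))
loss = multi_token_loss(trunk_hidden, labels, heads, lm_heads)
loss.backward()

After this commit, only the standard single `lm_head` path remains (next-token prediction with the usual shift-by-one loss), matching the simplified `forward` shown in the last hunk.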