aquiffoo committed on
Commit 4765ac9 · verified · 1 Parent(s): 0bf67d2

Update meshmodel.py

Files changed (1)
  1. meshmodel.py +87 -2
meshmodel.py CHANGED
@@ -1,3 +1,88 @@
- # Source code for MeshModel from cell VExhmWA0lXA_
- # Please replace this with the actual code from the notebook cell.
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from transformers.modeling_outputs import CausalLMOutputWithPast

+ class MeshModel(PreTrainedModel):
+     config_class = MeshConfig
+
+     def __init__(self, config: MeshConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         past_key_values=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         # Ensure return_dict defaults to the config setting if not specified
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+             inputs_embeds = self.embedding(input_ids)
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         hidden_states = inputs_embeds
+         expert_indices_list = []  # To collect expert indices from each layer
+
+         for i, layer in enumerate(self.layers):
+             hidden_states, expert_indices = layer(hidden_states)
+             expert_indices_list.append(expert_indices)  # Collect indices
+
+         hidden_states = self.norm(hidden_states)
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Compute the causal LM loss with CrossEntropyLoss
+             loss_fct = nn.CrossEntropyLoss()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Calculate scalar loss
+             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+         # Return a CausalLMOutputWithPast object or a tuple
+         if return_dict:
+             return CausalLMOutputWithPast(
+                 loss=loss,
+                 logits=logits,
+                 past_key_values=None,  # Caching not implemented yet
+                 hidden_states=hidden_states,
+                 attentions=None,  # Attention outputs not implemented yet
+             )
+         else:
+             # Return a tuple: the default Trainer expects loss first (for backward),
+             # followed by logits; expert_indices_list is an extra output that a
+             # custom callback or logging hook can consume.
+             return (loss, logits, hidden_states, expert_indices_list)
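
For context, a minimal usage sketch of the committed class. It assumes MeshConfig and MeshLayer (not part of this diff) are already defined in scope, and that MeshConfig exposes the vocab_size, hidden_size, num_hidden_layers, and rms_norm_eps fields the model reads; the concrete values below are illustrative only.

import torch

# Hypothetical configuration values; MeshConfig / MeshLayer are assumed to exist.
config = MeshConfig(
    vocab_size=32000,
    hidden_size=512,
    num_hidden_layers=4,
    rms_norm_eps=1e-6,
)
model = MeshModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
labels = input_ids.clone()

outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
print(outputs.loss)          # scalar language-modeling loss
print(outputs.logits.shape)  # (2, 16, config.vocab_size)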