aquiffoo committed · Commit 7ce8e21 · verified · 1 Parent(s): 4765ac9

Update meshlayer.py

Files changed (1)
  1. meshlayer.py +54 -2
meshlayer.py CHANGED
@@ -1,3 +1,55 @@
- # Source code for MeshLayer from cell 39f2782d
- # Please replace this with the actual code from the notebook cell.
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
 
+ # Define the main Mesh Layer
+ class MeshLayer(nn.Module):
+     def __init__(self, config: MeshConfig):
+         super().__init__()
+         self.config = config
+         self.router = MeshRouter(config)
+         self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
+         self.neighbor_exchange = NeighborExchange(config)
+         self.cross_expert_attention = CrossExpertAttention(config)
+
+     def forward(self, hidden_states):
+         # hidden_states shape: (batch_size, sequence_length, hidden_size)
+
+         # 1. Routing
+         topk_weights, topk_indices = self.router(hidden_states)
+         # topk_weights shape: (batch_size, sequence_length, k)
+         # topk_indices shape: (batch_size, sequence_length, k)
+
+         # Prepare expert inputs: repeat hidden_states for each expert
+         # shape: (batch_size, sequence_length, num_experts, hidden_size)
+         expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)
+
+         # 2. Expert Computation
+         # Compute output for all experts (can be optimized to only compute for selected experts)
+         expert_outputs = torch.stack([expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)], dim=2)
+         # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
+
+         # 3. Neighbor Exchange (conceptual implementation needed)
+         exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
+
+         # 4. Cross-Expert Attention (conceptual implementation needed)
+         cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)
+
+         # 5. Combine expert outputs based on routing weights
+         # Gather the outputs of the selected experts
+         # shape: (batch_size, sequence_length, k, hidden_size)
+         gathered_outputs = torch.gather(
+             cross_attned_expert_outputs,
+             dim=2,
+             index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
+         )
+
+         # Apply routing weights: (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
+         combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
+         # combined_output shape: (batch_size, sequence_length, hidden_size)
+
+         # Return the combined output and the expert indices for potential visualization
+         return combined_output, topk_indices # Return combined output and expert indices
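
Note: the committed file references MeshConfig, MeshRouter, MeshExpert, NeighborExchange, and CrossExpertAttention without defining or importing them, so it will not run on its own; those classes presumably live in other notebook cells. The following is a minimal, hypothetical sketch of the config and router that MeshLayer appears to assume: a plain top-k softmax gate over mesh_grid_size[0] * mesh_grid_size[1] experts. Field names such as intermediate_size and router_top_k are assumptions, not taken from the notebook.

# Hypothetical sketch (not from the notebook): config fields and top-k routing
# behavior are assumptions chosen to match what MeshLayer's forward pass expects.
from transformers import PretrainedConfig
import torch
import torch.nn as nn


class MeshConfig(PretrainedConfig):
    model_type = "mesh"

    def __init__(self, hidden_size=256, intermediate_size=512,
                 mesh_grid_size=(2, 2), router_top_k=2, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size   # assumed field
        self.mesh_grid_size = mesh_grid_size         # rows x cols of experts
        self.router_top_k = router_top_k             # assumed field


class MeshRouter(nn.Module):
    # Scores every expert per token and keeps the k best (assumed behavior).
    def __init__(self, config):
        super().__init__()
        num_experts = config.mesh_grid_size[0] * config.mesh_grid_size[1]
        self.gate = nn.Linear(config.hidden_size, num_experts)
        self.top_k = config.router_top_k

    def forward(self, hidden_states):
        # (batch, seq, num_experts) routing probabilities
        probs = torch.softmax(self.gate(hidden_states), dim=-1)
        # Keep the k most likely experts per token and renormalize their weights
        topk_weights, topk_indices = probs.topk(self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_indices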
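
The remaining three modules are only described by comments in the diff (steps 3 and 4 are explicitly marked "conceptual implementation needed"). The placeholders below are illustrative choices, not the author's design: MeshExpert as a standard MoE-style MLP, NeighborExchange as an average over up/down/left/right grid neighbors, and CrossExpertAttention as single-head attention over the expert axis. They only guarantee the tensor shapes MeshLayer needs.

# Hypothetical placeholders (not the author's design); shapes match MeshLayer's usage.
import torch
import torch.nn as nn


class MeshExpert(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.hidden_size, config.intermediate_size)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        # x: (batch, seq, hidden) -> (batch, seq, hidden)
        return self.down(self.act(self.up(x)))


class NeighborExchange(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rows, self.cols = config.mesh_grid_size

    def forward(self, expert_outputs, topk_indices):
        # expert_outputs: (batch, seq, num_experts, hidden)
        # Average each expert's output with its up/down/left/right grid neighbors.
        mixed = expert_outputs.clone()
        for r in range(self.rows):
            for c in range(self.cols):
                idx = r * self.cols + c
                neighbors = [idx]
                if r > 0:
                    neighbors.append((r - 1) * self.cols + c)
                if r < self.rows - 1:
                    neighbors.append((r + 1) * self.cols + c)
                if c > 0:
                    neighbors.append(idx - 1)
                if c < self.cols - 1:
                    neighbors.append(idx + 1)
                mixed[:, :, idx, :] = expert_outputs[:, :, neighbors, :].mean(dim=2)
        return mixed


class CrossExpertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = nn.MultiheadAttention(config.hidden_size, num_heads=1, batch_first=True)

    def forward(self, expert_outputs):
        # Fold (batch, seq) together so the expert axis becomes the attention sequence.
        b, s, e, h = expert_outputs.shape
        flat = expert_outputs.reshape(b * s, e, h)
        attned, _ = self.attn(flat, flat, flat)
        return attned.reshape(b, s, e, h)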
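
Assuming the sketches above together with the committed MeshLayer, a quick shape check of the forward pass looks like this: the layer returns the mixture output per token plus the top-k expert indices used for routing.

# Shape check under the assumed MeshConfig/MeshRouter/MeshExpert/NeighborExchange/
# CrossExpertAttention sketches above plus the committed MeshLayer.
config = MeshConfig(hidden_size=256, intermediate_size=512, mesh_grid_size=(2, 2), router_top_k=2)
layer = MeshLayer(config)
tokens = torch.randn(3, 11, config.hidden_size)   # (batch=3, seq=11, hidden=256)
output, expert_indices = layer(tokens)
print(output.shape)          # torch.Size([3, 11, 256])
print(expert_indices.shape)  # torch.Size([3, 11, 2]) -- top-k expert ids per token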