aquiffoo commited on
Commit
a089044
·
verified ·
1 Parent(s): c58e471

Upload 7 files

Browse files
Files changed (7) hide show
  1. crossexpertattention.py +40 -0
  2. meshconfig.py +64 -0
  3. meshexpert.py +17 -0
  4. meshlayer.py +55 -0
  5. meshmodel.py +88 -0
  6. meshrouter.py +27 -0
  7. neighborexchange.py +82 -0
crossexpertattention.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
7
+
8
+ # Define the Cross-Expert Attention mechanism
9
+ class CrossExpertAttention(nn.Module):
10
+ def __init__(self, config: MeshConfig):
11
+ super().__init__()
12
+ self.config = config
13
+ # Define multi-head attention layers or similar for cross-expert communication
14
+ # This is a placeholder and needs detailed implementation
15
+ self.cross_attention = nn.MultiheadAttention(
16
+ embed_dim=config.hidden_size,
17
+ num_heads=config.num_attention_heads, # Using model's attention heads for now
18
+ batch_first=True
19
+ )
20
+
21
+ def forward(self, expert_outputs):
22
+ # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
23
+
24
+ if not self.config.cross_expert_attention_enabled:
25
+ return expert_outputs
26
+
27
+ # Reshape for attention: (batch_size * sequence_length, num_experts, hidden_size)
28
+ batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
29
+ reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
30
+
31
+ # Apply cross-expert attention. Query, Key, Value are the same here (self-attention across experts)
32
+ # Attention mask could be used to restrict communication if needed
33
+ cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
34
+
35
+ # Reshape back: (batch_size, sequence_length, num_experts, hidden_size)
36
+ cross_attn_output = cross_attn_output.view(
37
+ expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
38
+ )
39
+
40
+ return cross_attn_output
meshconfig.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast
7
+
8
+ class MeshConfig(PretrainedConfig):
9
+ model_type = "mesh"
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size=32000,
14
+ hidden_size=768,
15
+ intermediate_size=2048,
16
+ num_hidden_layers=12,
17
+ num_attention_heads=12,
18
+ num_key_value_heads=12,
19
+ max_position_embeddings=4096,
20
+ initializer_range=0.02,
21
+ rms_norm_eps=1e-6,
22
+ use_cache=True,
23
+ pad_token_id=0,
24
+ bos_token_id=1,
25
+ eos_token_id=2,
26
+ tie_word_embeddings=False,
27
+ # Mesh specific configurations
28
+ mesh_grid_size=(2, 2), # 2x2 grid
29
+ expert_intermediate_size=256, # Example size for expert intermediate layer
30
+ routing_k=2, # Top-k routing
31
+ neighbor_exchange_enabled=True,
32
+ cross_expert_attention_enabled=True,
33
+ **kwargs
34
+ ):
35
+ super().__init__(
36
+ vocab_size=vocab_size,
37
+ hidden_size=hidden_size,
38
+ intermediate_size=intermediate_size,
39
+ num_hidden_layers=num_hidden_layers,
40
+ num_attention_heads=num_attention_heads,
41
+ num_key_value_heads=num_key_value_heads,
42
+ max_position_embeddings=max_position_embeddings,
43
+ initializer_range=initializer_range,
44
+ rms_norm_eps=rms_norm_eps,
45
+ use_cache=use_cache,
46
+ pad_token_id=pad_token_id,
47
+ bos_token_id=bos_token_id,
48
+ eos_token_id=eos_token_id,
49
+ tie_word_embeddings=tie_word_embeddings,
50
+ **kwargs,
51
+ )
52
+ self.mesh_grid_size = mesh_grid_size
53
+ # Calculate expert_intermediate_size based on the shared and expert parameter split
54
+ # Total parameters = Shared (Embedding, Norm, LM Head) + Experts + Overhead
55
+ # This calculation is complex and depends on the specific layer mapping.
56
+ # For now, let's use a placeholder or calculate it based on the target parameter count.
57
+ # Target A242M (top-2): 100M shared + 135M (2 experts) + 7M overhead = 242M
58
+ # Let's assume the 135M for 2 experts is primarily in the intermediate size.
59
+ # We need to determine how Gemma's intermediate size maps to the expert intermediate size.
60
+ # For now, I will keep a placeholder or a simple ratio.
61
+ self.expert_intermediate_size = intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1]) # Example: divide intermediate size by number of experts
62
+ self.routing_k = routing_k
63
+ self.neighbor_exchange_enabled = neighbor_exchange_enabled
64
+ self.cross_expert_attention_enabled = cross_expert_attention_enabled
meshexpert.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
7
+
8
+ # Define a single Expert within the Mesh
9
+ class MeshExpert(nn.Module):
10
+ def __init__(self, config: MeshConfig):
11
+ super().__init__()
12
+ self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
13
+ self.gelu = nn.GELU() # Using GELU as an example activation
14
+ self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)
15
+
16
+ def forward(self, x):
17
+ return self.fc2(self.gelu(self.fc1(x)))
meshlayer.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
7
+
8
+ # Define the main Mesh Layer
9
+ class MeshLayer(nn.Module):
10
+ def __init__(self, config: MeshConfig):
11
+ super().__init__()
12
+ self.config = config
13
+ self.router = MeshRouter(config)
14
+ self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
15
+ self.neighbor_exchange = NeighborExchange(config)
16
+ self.cross_expert_attention = CrossExpertAttention(config)
17
+
18
+ def forward(self, hidden_states):
19
+ # hidden_states shape: (batch_size, sequence_length, hidden_size)
20
+
21
+ # 1. Routing
22
+ topk_weights, topk_indices = self.router(hidden_states)
23
+ # topk_weights shape: (batch_size, sequence_length, k)
24
+ # topk_indices shape: (batch_size, sequence_length, k)
25
+
26
+ # Prepare expert inputs: repeat hidden_states for each expert
27
+ # shape: (batch_size, sequence_length, num_experts, hidden_size)
28
+ expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)
29
+
30
+ # 2. Expert Computation
31
+ # Compute output for all experts (can be optimized to only compute for selected experts)
32
+ expert_outputs = torch.stack([expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)], dim=2)
33
+ # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
34
+
35
+ # 3. Neighbor Exchange (conceptual implementation needed)
36
+ exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
37
+
38
+ # 4. Cross-Expert Attention (conceptual implementation needed)
39
+ cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)
40
+
41
+ # 5. Combine expert outputs based on routing weights
42
+ # Create a tensor to gather the outputs of the selected experts
43
+ # shape: (batch_size, sequence_length, k, hidden_size)
44
+ gathered_outputs = torch.gather(
45
+ cross_attned_expert_outputs,
46
+ dim=2,
47
+ index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
48
+ )
49
+
50
+ # Apply routing weights: (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
51
+ combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
52
+ # combined_output shape: (batch_size, sequence_length, hidden_size)
53
+
54
+ # Return the combined output and the expert indices for potential visualization
55
+ return combined_output, topk_indices # Return combined output and expert indices
meshmodel.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast
7
+
8
+ class MeshModel(PreTrainedModel):
9
+ config_class = MeshConfig
10
+
11
+ def __init__(self, config: MeshConfig):
12
+ super().__init__(config)
13
+ self.config = config
14
+
15
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
16
+ self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
17
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
18
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
19
+
20
+ self.post_init()
21
+
22
+ def forward(
23
+ self,
24
+ input_ids=None,
25
+ attention_mask=None,
26
+ token_type_ids=None,
27
+ position_ids=None,
28
+ head_mask=None,
29
+ inputs_embeds=None,
30
+ encoder_hidden_states=None,
31
+ encoder_attention_mask=None,
32
+ labels=None,
33
+ past_key_values=None,
34
+ use_cache=None,
35
+ output_attentions=None,
36
+ output_hidden_states=None,
37
+ return_dict=None,
38
+ ):
39
+ # Ensure return_dict is set to True by default if not specified
40
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
41
+
42
+ if input_ids is not None and inputs_embeds is not None:
43
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
44
+ elif input_ids is not None:
45
+ input_shape = input_ids.size()
46
+ inputs_embeds = self.embedding(input_ids)
47
+ elif inputs_embeds is not None:
48
+ input_shape = inputs_embeds.size()[:-1]
49
+ else:
50
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
51
+
52
+ hidden_states = inputs_embeds
53
+ expert_indices_list = [] # To collect expert indices from each layer
54
+
55
+ for i, layer in enumerate(self.layers):
56
+ hidden_states, expert_indices = layer(hidden_states)
57
+ expert_indices_list.append(expert_indices) # Collect indices
58
+
59
+ hidden_states = self.norm(hidden_states)
60
+ logits = self.lm_head(hidden_states)
61
+
62
+ loss = None
63
+ if labels is not None:
64
+ # Compute loss (e.g., CrossEntropyLoss)
65
+ loss_fct = nn.CrossEntropyLoss()
66
+ # Shift so that tokens < n predict n
67
+ shift_logits = logits[..., :-1, :].contiguous()
68
+ shift_labels = labels[..., 1:].contiguous()
69
+ # Calculate scalar loss
70
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
71
+
72
+ # Return a CausalLMOutputWithPast object or a tuple
73
+ if return_dict:
74
+ return CausalLMOutputWithPast(
75
+ loss=loss,
76
+ logits=logits,
77
+ past_key_values=None, # Need to implement caching
78
+ hidden_states=hidden_states,
79
+ attentions=None, # Need to implement attention handling
80
+ )
81
+ else:
82
+ # Return a tuple including loss, logits, and collected expert indices
83
+ # Ensure the order and content match what the Trainer expects or can handle
84
+ # Trainer expects (loss, logits, hidden_states, attentions) or similar.
85
+ # We can return (loss, logits) as the primary outputs for the Trainer
86
+ # and potentially include expert_indices as an additional output if needed
87
+ # by a custom callback or logging, but the default Trainer expects loss as the first element for backward.
88
+ return (loss, logits, hidden_states, expert_indices_list) # Include expert_indices_list
meshrouter.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
7
+
8
+ # Define the Router for dynamic routing
9
+ class MeshRouter(nn.Module):
10
+ def __init__(self, config: MeshConfig):
11
+ super().__init__()
12
+ self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
13
+ self.softmax = nn.Softmax(dim=-1)
14
+ self.routing_k = config.routing_k
15
+
16
+ def forward(self, x):
17
+ # x shape: (batch_size, sequence_length, hidden_size)
18
+ gate_scores = self.gate(x) # shape: (batch_size, sequence_length, num_experts)
19
+ gate_weights = self.softmax(gate_scores)
20
+
21
+ # Select top-k experts
22
+ topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
23
+
24
+ # Normalize top-k weights
25
+ topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
26
+
27
+ return topk_weights, topk_indices # shapes: (batch_size, sequence_length, k), (batch_size, sequence_length, k)
neighborexchange.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class NeighborExchange(nn.Module):
5
+ def __init__(self, config: MeshConfig):
6
+ super().__init__()
7
+ self.config = config
8
+ self.num_experts_x = config.mesh_grid_size[0]
9
+ self.num_experts_y = config.mesh_grid_size[1]
10
+ self.num_experts = self.num_experts_x * self.num_experts_y
11
+
12
+ # Define parameters for neighbor communication.
13
+ # A simple approach: a learned linear combination of neighbor features.
14
+ # We can define a weight for each potential neighbor direction (e.g., up, down, left, right).
15
+ # For a 2x2 grid, each expert has 2 or 3 neighbors.
16
+ # A more general approach is a linear layer that takes concatenated neighbor features.
17
+ # Let's use a linear layer to transform the aggregated neighbor information.
18
+ # The input size to this layer will be the sum of hidden sizes of all potential neighbors
19
+ # multiplied by the hidden size, but that's too complex.
20
+ # A simpler approach: a linear layer per direction, or a single layer after aggregating.
21
+
22
+ # Let's define a linear layer to process the information received from neighbors.
23
+ # The input size is the hidden size (from neighbors), output size is hidden size
24
+ # This layer will transform the aggregated neighbor features before adding to the expert's own output.
25
+ self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size) # Projects aggregated neighbor info
26
+
27
+ # Optional: Learned weights for different neighbor directions
28
+ # self.neighbor_weights = nn.Parameter(torch.ones(4)) # Example for 4 directions (N, S, E, W)
29
+
30
+ def forward(self, expert_outputs, expert_indices=None):
31
+ # expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
32
+ # expert_indices shape: (batch_size, sequence_length, k) - indices of selected experts (not directly used for neighbor exchange in this simple model)
33
+
34
+ if not self.config.neighbor_exchange_enabled:
35
+ return expert_outputs
36
+
37
+ batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
38
+
39
+ # Reshape expert_outputs to reflect the grid structure (batch_size, seq_length, grid_x, grid_y, hidden_size)
40
+ reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
41
+
42
+ # Create a tensor to store the aggregated neighbor information for each expert
43
+ aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)
44
+
45
+ # Implement neighbor exchange logic
46
+ # Iterate through each expert in the grid
47
+ for i in range(self.num_experts_x):
48
+ for j in range(self.num_experts_y):
49
+ current_expert_output = reshaped_outputs[:, :, i, j, :]
50
+ neighbor_info = torch.zeros_like(current_expert_output) # Accumulate info from neighbors
51
+
52
+ # Define neighbor directions (example: up, down, left, right)
53
+ neighbors = []
54
+ if i > 0: # Up neighbor
55
+ neighbors.append(reshaped_outputs[:, :, i-1, j, :])
56
+ if i < self.num_experts_x - 1: # Down neighbor
57
+ neighbors.append(reshaped_outputs[:, :, i+1, j, :])
58
+ if j > 0: # Left neighbor
59
+ neighbors.append(reshaped_outputs[:, :, i, j-1, :])
60
+ if j < self.num_experts_y - 1: # Right neighbor
61
+ neighbors.append(reshaped_outputs[:, :, i, j+1, :])
62
+
63
+ # Aggregate information from neighbors (simple average as an example)
64
+ if neighbors:
65
+ # Stack neighbors along a new dimension and take the mean
66
+ neighbor_stack = torch.stack(neighbors, dim=-2) # shape (batch, seq, num_neighbors, hidden)
67
+ aggregated_info = torch.mean(neighbor_stack, dim=-2) # shape (batch, seq, hidden)
68
+ neighbor_info = aggregated_info # Use the aggregated info
69
+
70
+ # Apply the exchange projection to the aggregated neighbor information
71
+ transformed_neighbor_info = self.exchange_projection(neighbor_info)
72
+
73
+ # Store the transformed neighbor info for the current expert's position
74
+ aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info
75
+
76
+ # Reshape aggregated_neighbor_info back to (batch_size, sequence_length, num_experts, hidden_size)
77
+ aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
78
+
79
+ # Combine expert outputs with aggregated neighbor information (additive combination)
80
+ exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info
81
+
82
+ return exchanged_expert_outputs