Upload 7 files
Browse files- crossexpertattention.py +40 -0
- meshconfig.py +64 -0
- meshexpert.py +17 -0
- meshlayer.py +55 -0
- meshmodel.py +88 -0
- meshrouter.py +27 -0
- neighborexchange.py +82 -0
crossexpertattention.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
7 |
+
|
8 |
+
# Define the Cross-Expert Attention mechanism
|
9 |
+
class CrossExpertAttention(nn.Module):
|
10 |
+
def __init__(self, config: MeshConfig):
|
11 |
+
super().__init__()
|
12 |
+
self.config = config
|
13 |
+
# Define multi-head attention layers or similar for cross-expert communication
|
14 |
+
# This is a placeholder and needs detailed implementation
|
15 |
+
self.cross_attention = nn.MultiheadAttention(
|
16 |
+
embed_dim=config.hidden_size,
|
17 |
+
num_heads=config.num_attention_heads, # Using model's attention heads for now
|
18 |
+
batch_first=True
|
19 |
+
)
|
20 |
+
|
21 |
+
def forward(self, expert_outputs):
|
22 |
+
# expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
|
23 |
+
|
24 |
+
if not self.config.cross_expert_attention_enabled:
|
25 |
+
return expert_outputs
|
26 |
+
|
27 |
+
# Reshape for attention: (batch_size * sequence_length, num_experts, hidden_size)
|
28 |
+
batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
|
29 |
+
reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
|
30 |
+
|
31 |
+
# Apply cross-expert attention. Query, Key, Value are the same here (self-attention across experts)
|
32 |
+
# Attention mask could be used to restrict communication if needed
|
33 |
+
cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
|
34 |
+
|
35 |
+
# Reshape back: (batch_size, sequence_length, num_experts, hidden_size)
|
36 |
+
cross_attn_output = cross_attn_output.view(
|
37 |
+
expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
|
38 |
+
)
|
39 |
+
|
40 |
+
return cross_attn_output
|
meshconfig.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
7 |
+
|
8 |
+
class MeshConfig(PretrainedConfig):
|
9 |
+
model_type = "mesh"
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
vocab_size=32000,
|
14 |
+
hidden_size=768,
|
15 |
+
intermediate_size=2048,
|
16 |
+
num_hidden_layers=12,
|
17 |
+
num_attention_heads=12,
|
18 |
+
num_key_value_heads=12,
|
19 |
+
max_position_embeddings=4096,
|
20 |
+
initializer_range=0.02,
|
21 |
+
rms_norm_eps=1e-6,
|
22 |
+
use_cache=True,
|
23 |
+
pad_token_id=0,
|
24 |
+
bos_token_id=1,
|
25 |
+
eos_token_id=2,
|
26 |
+
tie_word_embeddings=False,
|
27 |
+
# Mesh specific configurations
|
28 |
+
mesh_grid_size=(2, 2), # 2x2 grid
|
29 |
+
expert_intermediate_size=256, # Example size for expert intermediate layer
|
30 |
+
routing_k=2, # Top-k routing
|
31 |
+
neighbor_exchange_enabled=True,
|
32 |
+
cross_expert_attention_enabled=True,
|
33 |
+
**kwargs
|
34 |
+
):
|
35 |
+
super().__init__(
|
36 |
+
vocab_size=vocab_size,
|
37 |
+
hidden_size=hidden_size,
|
38 |
+
intermediate_size=intermediate_size,
|
39 |
+
num_hidden_layers=num_hidden_layers,
|
40 |
+
num_attention_heads=num_attention_heads,
|
41 |
+
num_key_value_heads=num_key_value_heads,
|
42 |
+
max_position_embeddings=max_position_embeddings,
|
43 |
+
initializer_range=initializer_range,
|
44 |
+
rms_norm_eps=rms_norm_eps,
|
45 |
+
use_cache=use_cache,
|
46 |
+
pad_token_id=pad_token_id,
|
47 |
+
bos_token_id=bos_token_id,
|
48 |
+
eos_token_id=eos_token_id,
|
49 |
+
tie_word_embeddings=tie_word_embeddings,
|
50 |
+
**kwargs,
|
51 |
+
)
|
52 |
+
self.mesh_grid_size = mesh_grid_size
|
53 |
+
# Calculate expert_intermediate_size based on the shared and expert parameter split
|
54 |
+
# Total parameters = Shared (Embedding, Norm, LM Head) + Experts + Overhead
|
55 |
+
# This calculation is complex and depends on the specific layer mapping.
|
56 |
+
# For now, let's use a placeholder or calculate it based on the target parameter count.
|
57 |
+
# Target A242M (top-2): 100M shared + 135M (2 experts) + 7M overhead = 242M
|
58 |
+
# Let's assume the 135M for 2 experts is primarily in the intermediate size.
|
59 |
+
# We need to determine how Gemma's intermediate size maps to the expert intermediate size.
|
60 |
+
# For now, I will keep a placeholder or a simple ratio.
|
61 |
+
self.expert_intermediate_size = intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1]) # Example: divide intermediate size by number of experts
|
62 |
+
self.routing_k = routing_k
|
63 |
+
self.neighbor_exchange_enabled = neighbor_exchange_enabled
|
64 |
+
self.cross_expert_attention_enabled = cross_expert_attention_enabled
|
meshexpert.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
7 |
+
|
8 |
+
# Define a single Expert within the Mesh
|
9 |
+
class MeshExpert(nn.Module):
|
10 |
+
def __init__(self, config: MeshConfig):
|
11 |
+
super().__init__()
|
12 |
+
self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
|
13 |
+
self.gelu = nn.GELU() # Using GELU as an example activation
|
14 |
+
self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
return self.fc2(self.gelu(self.fc1(x)))
|
meshlayer.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
7 |
+
|
8 |
+
# Define the main Mesh Layer
|
9 |
+
class MeshLayer(nn.Module):
|
10 |
+
def __init__(self, config: MeshConfig):
|
11 |
+
super().__init__()
|
12 |
+
self.config = config
|
13 |
+
self.router = MeshRouter(config)
|
14 |
+
self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
|
15 |
+
self.neighbor_exchange = NeighborExchange(config)
|
16 |
+
self.cross_expert_attention = CrossExpertAttention(config)
|
17 |
+
|
18 |
+
def forward(self, hidden_states):
|
19 |
+
# hidden_states shape: (batch_size, sequence_length, hidden_size)
|
20 |
+
|
21 |
+
# 1. Routing
|
22 |
+
topk_weights, topk_indices = self.router(hidden_states)
|
23 |
+
# topk_weights shape: (batch_size, sequence_length, k)
|
24 |
+
# topk_indices shape: (batch_size, sequence_length, k)
|
25 |
+
|
26 |
+
# Prepare expert inputs: repeat hidden_states for each expert
|
27 |
+
# shape: (batch_size, sequence_length, num_experts, hidden_size)
|
28 |
+
expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)
|
29 |
+
|
30 |
+
# 2. Expert Computation
|
31 |
+
# Compute output for all experts (can be optimized to only compute for selected experts)
|
32 |
+
expert_outputs = torch.stack([expert(expanded_hidden_states[:, :, i, :]) for i, expert in enumerate(self.experts)], dim=2)
|
33 |
+
# expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
|
34 |
+
|
35 |
+
# 3. Neighbor Exchange (conceptual implementation needed)
|
36 |
+
exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
|
37 |
+
|
38 |
+
# 4. Cross-Expert Attention (conceptual implementation needed)
|
39 |
+
cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)
|
40 |
+
|
41 |
+
# 5. Combine expert outputs based on routing weights
|
42 |
+
# Create a tensor to gather the outputs of the selected experts
|
43 |
+
# shape: (batch_size, sequence_length, k, hidden_size)
|
44 |
+
gathered_outputs = torch.gather(
|
45 |
+
cross_attned_expert_outputs,
|
46 |
+
dim=2,
|
47 |
+
index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
|
48 |
+
)
|
49 |
+
|
50 |
+
# Apply routing weights: (batch_size, sequence_length, k, 1) * (batch_size, sequence_length, k, hidden_size)
|
51 |
+
combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
|
52 |
+
# combined_output shape: (batch_size, sequence_length, hidden_size)
|
53 |
+
|
54 |
+
# Return the combined output and the expert indices for potential visualization
|
55 |
+
return combined_output, topk_indices # Return combined output and expert indices
|
meshmodel.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
7 |
+
|
8 |
+
class MeshModel(PreTrainedModel):
|
9 |
+
config_class = MeshConfig
|
10 |
+
|
11 |
+
def __init__(self, config: MeshConfig):
|
12 |
+
super().__init__(config)
|
13 |
+
self.config = config
|
14 |
+
|
15 |
+
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
16 |
+
self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
|
17 |
+
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
18 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
19 |
+
|
20 |
+
self.post_init()
|
21 |
+
|
22 |
+
def forward(
|
23 |
+
self,
|
24 |
+
input_ids=None,
|
25 |
+
attention_mask=None,
|
26 |
+
token_type_ids=None,
|
27 |
+
position_ids=None,
|
28 |
+
head_mask=None,
|
29 |
+
inputs_embeds=None,
|
30 |
+
encoder_hidden_states=None,
|
31 |
+
encoder_attention_mask=None,
|
32 |
+
labels=None,
|
33 |
+
past_key_values=None,
|
34 |
+
use_cache=None,
|
35 |
+
output_attentions=None,
|
36 |
+
output_hidden_states=None,
|
37 |
+
return_dict=None,
|
38 |
+
):
|
39 |
+
# Ensure return_dict is set to True by default if not specified
|
40 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
41 |
+
|
42 |
+
if input_ids is not None and inputs_embeds is not None:
|
43 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
44 |
+
elif input_ids is not None:
|
45 |
+
input_shape = input_ids.size()
|
46 |
+
inputs_embeds = self.embedding(input_ids)
|
47 |
+
elif inputs_embeds is not None:
|
48 |
+
input_shape = inputs_embeds.size()[:-1]
|
49 |
+
else:
|
50 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
51 |
+
|
52 |
+
hidden_states = inputs_embeds
|
53 |
+
expert_indices_list = [] # To collect expert indices from each layer
|
54 |
+
|
55 |
+
for i, layer in enumerate(self.layers):
|
56 |
+
hidden_states, expert_indices = layer(hidden_states)
|
57 |
+
expert_indices_list.append(expert_indices) # Collect indices
|
58 |
+
|
59 |
+
hidden_states = self.norm(hidden_states)
|
60 |
+
logits = self.lm_head(hidden_states)
|
61 |
+
|
62 |
+
loss = None
|
63 |
+
if labels is not None:
|
64 |
+
# Compute loss (e.g., CrossEntropyLoss)
|
65 |
+
loss_fct = nn.CrossEntropyLoss()
|
66 |
+
# Shift so that tokens < n predict n
|
67 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
68 |
+
shift_labels = labels[..., 1:].contiguous()
|
69 |
+
# Calculate scalar loss
|
70 |
+
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
71 |
+
|
72 |
+
# Return a CausalLMOutputWithPast object or a tuple
|
73 |
+
if return_dict:
|
74 |
+
return CausalLMOutputWithPast(
|
75 |
+
loss=loss,
|
76 |
+
logits=logits,
|
77 |
+
past_key_values=None, # Need to implement caching
|
78 |
+
hidden_states=hidden_states,
|
79 |
+
attentions=None, # Need to implement attention handling
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
# Return a tuple including loss, logits, and collected expert indices
|
83 |
+
# Ensure the order and content match what the Trainer expects or can handle
|
84 |
+
# Trainer expects (loss, logits, hidden_states, attentions) or similar.
|
85 |
+
# We can return (loss, logits) as the primary outputs for the Trainer
|
86 |
+
# and potentially include expert_indices as an additional output if needed
|
87 |
+
# by a custom callback or logging, but the default Trainer expects loss as the first element for backward.
|
88 |
+
return (loss, logits, hidden_states, expert_indices_list) # Include expert_indices_list
|
meshrouter.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
7 |
+
|
8 |
+
# Define the Router for dynamic routing
|
9 |
+
class MeshRouter(nn.Module):
|
10 |
+
def __init__(self, config: MeshConfig):
|
11 |
+
super().__init__()
|
12 |
+
self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
|
13 |
+
self.softmax = nn.Softmax(dim=-1)
|
14 |
+
self.routing_k = config.routing_k
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
# x shape: (batch_size, sequence_length, hidden_size)
|
18 |
+
gate_scores = self.gate(x) # shape: (batch_size, sequence_length, num_experts)
|
19 |
+
gate_weights = self.softmax(gate_scores)
|
20 |
+
|
21 |
+
# Select top-k experts
|
22 |
+
topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
|
23 |
+
|
24 |
+
# Normalize top-k weights
|
25 |
+
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
|
26 |
+
|
27 |
+
return topk_weights, topk_indices # shapes: (batch_size, sequence_length, k), (batch_size, sequence_length, k)
|
neighborexchange.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class NeighborExchange(nn.Module):
|
5 |
+
def __init__(self, config: MeshConfig):
|
6 |
+
super().__init__()
|
7 |
+
self.config = config
|
8 |
+
self.num_experts_x = config.mesh_grid_size[0]
|
9 |
+
self.num_experts_y = config.mesh_grid_size[1]
|
10 |
+
self.num_experts = self.num_experts_x * self.num_experts_y
|
11 |
+
|
12 |
+
# Define parameters for neighbor communication.
|
13 |
+
# A simple approach: a learned linear combination of neighbor features.
|
14 |
+
# We can define a weight for each potential neighbor direction (e.g., up, down, left, right).
|
15 |
+
# For a 2x2 grid, each expert has 2 or 3 neighbors.
|
16 |
+
# A more general approach is a linear layer that takes concatenated neighbor features.
|
17 |
+
# Let's use a linear layer to transform the aggregated neighbor information.
|
18 |
+
# The input size to this layer will be the sum of hidden sizes of all potential neighbors
|
19 |
+
# multiplied by the hidden size, but that's too complex.
|
20 |
+
# A simpler approach: a linear layer per direction, or a single layer after aggregating.
|
21 |
+
|
22 |
+
# Let's define a linear layer to process the information received from neighbors.
|
23 |
+
# The input size is the hidden size (from neighbors), output size is hidden size
|
24 |
+
# This layer will transform the aggregated neighbor features before adding to the expert's own output.
|
25 |
+
self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size) # Projects aggregated neighbor info
|
26 |
+
|
27 |
+
# Optional: Learned weights for different neighbor directions
|
28 |
+
# self.neighbor_weights = nn.Parameter(torch.ones(4)) # Example for 4 directions (N, S, E, W)
|
29 |
+
|
30 |
+
def forward(self, expert_outputs, expert_indices=None):
|
31 |
+
# expert_outputs shape: (batch_size, sequence_length, num_experts, hidden_size)
|
32 |
+
# expert_indices shape: (batch_size, sequence_length, k) - indices of selected experts (not directly used for neighbor exchange in this simple model)
|
33 |
+
|
34 |
+
if not self.config.neighbor_exchange_enabled:
|
35 |
+
return expert_outputs
|
36 |
+
|
37 |
+
batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
|
38 |
+
|
39 |
+
# Reshape expert_outputs to reflect the grid structure (batch_size, seq_length, grid_x, grid_y, hidden_size)
|
40 |
+
reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
|
41 |
+
|
42 |
+
# Create a tensor to store the aggregated neighbor information for each expert
|
43 |
+
aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)
|
44 |
+
|
45 |
+
# Implement neighbor exchange logic
|
46 |
+
# Iterate through each expert in the grid
|
47 |
+
for i in range(self.num_experts_x):
|
48 |
+
for j in range(self.num_experts_y):
|
49 |
+
current_expert_output = reshaped_outputs[:, :, i, j, :]
|
50 |
+
neighbor_info = torch.zeros_like(current_expert_output) # Accumulate info from neighbors
|
51 |
+
|
52 |
+
# Define neighbor directions (example: up, down, left, right)
|
53 |
+
neighbors = []
|
54 |
+
if i > 0: # Up neighbor
|
55 |
+
neighbors.append(reshaped_outputs[:, :, i-1, j, :])
|
56 |
+
if i < self.num_experts_x - 1: # Down neighbor
|
57 |
+
neighbors.append(reshaped_outputs[:, :, i+1, j, :])
|
58 |
+
if j > 0: # Left neighbor
|
59 |
+
neighbors.append(reshaped_outputs[:, :, i, j-1, :])
|
60 |
+
if j < self.num_experts_y - 1: # Right neighbor
|
61 |
+
neighbors.append(reshaped_outputs[:, :, i, j+1, :])
|
62 |
+
|
63 |
+
# Aggregate information from neighbors (simple average as an example)
|
64 |
+
if neighbors:
|
65 |
+
# Stack neighbors along a new dimension and take the mean
|
66 |
+
neighbor_stack = torch.stack(neighbors, dim=-2) # shape (batch, seq, num_neighbors, hidden)
|
67 |
+
aggregated_info = torch.mean(neighbor_stack, dim=-2) # shape (batch, seq, hidden)
|
68 |
+
neighbor_info = aggregated_info # Use the aggregated info
|
69 |
+
|
70 |
+
# Apply the exchange projection to the aggregated neighbor information
|
71 |
+
transformed_neighbor_info = self.exchange_projection(neighbor_info)
|
72 |
+
|
73 |
+
# Store the transformed neighbor info for the current expert's position
|
74 |
+
aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info
|
75 |
+
|
76 |
+
# Reshape aggregated_neighbor_info back to (batch_size, sequence_length, num_experts, hidden_size)
|
77 |
+
aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
|
78 |
+
|
79 |
+
# Combine expert outputs with aggregated neighbor information (additive combination)
|
80 |
+
exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info
|
81 |
+
|
82 |
+
return exchanged_expert_outputs
|