Mesh-v0.1-2x2 (Stage 003)
Introducing mesh
This is our first-ever model! Allow us to explain in detail how the mesh architecture works.
Neural Mesh extends the concept of Mixture of Experts (MoE) by allowing bidirectional expert communication.
The experts are arranged in a two-dimensional grid (2x2, 4x4, etc.), which lets them communicate with their neighbors through the "Neighbor Exchange" mechanism.
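As a quick, standalone illustration (not part of the model code below), each expert only exchanges information with the cells directly above, below, to the left, and to the right of it in the grid:

# Illustrative sketch of the neighbor topology used by "Neighbor Exchange".
grid_x, grid_y = 2, 2  # mesh_grid_size for this model
for i in range(grid_x):
    for j in range(grid_y):
        neighbors = [(i + di, j + dj)
                     for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1))
                     if 0 <= i + di < grid_x and 0 <= j + dj < grid_y]
        print(f"expert ({i}, {j}) exchanges with {neighbors}")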
Just like MoE models, Mesh models use dynamic routing, and the routing_k parameter lets you set the number of active parameters (a short loading sketch follows the list). For this model (2x2):
- top-1 routing: 173M active parameters
- top-2 routing: 242M active parameters (default)
- dense routing: 302M active parameters
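Here is a minimal, hypothetical sketch of overriding routing_k when loading the checkpoint. It assumes the saved weights are independent of the routing choice (every expert is always instantiated); the repo name is the same one used in the loading code further down.

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mesh-labs/v0.1-2x2-stage003", trust_remote_code=True)
config.routing_k = 1  # top-1 routing (~173M active parameters); the default is 2
model = AutoModelForCausalLM.from_pretrained(
    "mesh-labs/v0.1-2x2-stage003",
    config=config,
    trust_remote_code=True,
)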
Here's how the mesh architecture works in practice: the configuration, experts, router, Neighbor Exchange, and Cross-Expert Attention are all defined in the loading code below.
How to load the model
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
import os
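# Configuration: standard causal-LM hyperparameters plus the mesh-specific options
# (grid size, per-expert width, routing_k, Neighbor Exchange and Cross-Expert Attention toggles).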
class MeshConfig(PretrainedConfig):
model_type = "mesh"
def __init__(
self,
vocab_size=32000,
hidden_size=768,
intermediate_size=2048,
num_hidden_layers=12,
num_attention_heads=12,
num_key_value_heads=12,
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
mesh_grid_size=(2, 2),
expert_intermediate_size=256,
routing_k=2,
neighbor_exchange_enabled=True,
cross_expert_attention_enabled=True,
expert_scale_factor="sqrt_k",
load_in_8bit=False,
load_in_4bit=False,
**kwargs
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.mesh_grid_size = mesh_grid_size
        self.expert_intermediate_size = expert_intermediate_size
self.routing_k = routing_k
self.neighbor_exchange_enabled = neighbor_exchange_enabled
self.cross_expert_attention_enabled = cross_expert_attention_enabled
self.expert_scale_factor = expert_scale_factor
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
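# A single expert: a small two-layer feed-forward network with a GELU activation.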
class MeshExpert(nn.Module):
def __init__(self, config: MeshConfig):
super().__init__()
self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
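# Router: scores each token against every expert and keeps the routing_k highest-weighted ones.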
class MeshRouter(nn.Module):
def __init__(self, config: MeshConfig):
super().__init__()
self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
self.softmax = nn.Softmax(dim=-1)
self.routing_k = config.routing_k
def forward(self, x):
gate_scores = self.gate(x)
gate_weights = self.softmax(gate_scores)
topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
return topk_weights, topk_indices
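# Neighbor Exchange: each expert's output is augmented with a projected average of the
# outputs of its direct grid neighbors (up, down, left, right).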
class NeighborExchange(nn.Module):
def __init__(self, config: MeshConfig):
super().__init__()
self.config = config
self.num_experts_x = config.mesh_grid_size[0]
self.num_experts_y = config.mesh_grid_size[1]
self.num_experts = self.num_experts_x * self.num_experts_y
self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, expert_outputs, expert_indices=None):
if not self.config.neighbor_exchange_enabled:
return expert_outputs
batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)
for i in range(self.num_experts_x):
for j in range(self.num_experts_y):
current_expert_output = reshaped_outputs[:, :, i, j, :]
neighbor_info = torch.zeros_like(current_expert_output)
neighbors = []
if i > 0: neighbors.append(reshaped_outputs[:, :, i-1, j, :])
if i < self.num_experts_x - 1: neighbors.append(reshaped_outputs[:, :, i+1, j, :])
if j > 0: neighbors.append(reshaped_outputs[:, :, i, j-1, :])
if j < self.num_experts_y - 1: neighbors.append(reshaped_outputs[:, :, i, j+1, :])
if neighbors:
neighbor_stack = torch.stack(neighbors, dim=-2)
aggregated_info = torch.mean(neighbor_stack, dim=-2)
neighbor_info = aggregated_info
transformed_neighbor_info = self.exchange_projection(neighbor_info)
aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info
aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info
return exchanged_expert_outputs
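# Cross-Expert Attention: at every token position, the expert outputs attend to each other.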
class CrossExpertAttention(nn.Module):
def __init__(self, config: MeshConfig):
super().__init__()
self.config = config
self.cross_attention = nn.MultiheadAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
batch_first=True
)
def forward(self, expert_outputs):
if not self.config.cross_expert_attention_enabled:
return expert_outputs
batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
cross_attn_output = cross_attn_output.view(
expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
)
return cross_attn_output
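# One mesh layer: route, run all experts, exchange with neighbors, apply cross-expert
# attention, then combine the top-k expert outputs weighted by the router scores.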
class MeshLayer(nn.Module):
def __init__(self, config: MeshConfig):
super().__init__()
self.config = config
self.router = MeshRouter(config)
self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
self.neighbor_exchange = NeighborExchange(config)
self.cross_expert_attention = CrossExpertAttention(config)
def forward(self, hidden_states):
topk_weights, topk_indices = self.router(hidden_states)
expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)
if self.config.expert_scale_factor == "sqrt_k":
scaling_factor = math.sqrt(self.config.routing_k)
scaled_expert_inputs = expanded_hidden_states * scaling_factor
elif self.config.expert_scale_factor == "1_over_k":
scaling_factor = 1.0 / self.config.routing_k
scaled_expert_inputs = expanded_hidden_states * scaling_factor
else:
scaled_expert_inputs = expanded_hidden_states
expert_outputs_list = [expert(scaled_expert_inputs[:, :, i, :]) for i, expert in enumerate(self.experts)]
expert_outputs = torch.stack(expert_outputs_list, dim=2)
exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)
gathered_outputs = torch.gather(
cross_attned_expert_outputs,
dim=2,
index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
)
combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)
return combined_output, topk_indices
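# The full causal language model: token embeddings, a stack of mesh layers,
# a final normalization, and an LM head.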
class MeshModel(PreTrainedModel, GenerationMixin):
config_class = MeshConfig
def __init__(self, config: MeshConfig):
super().__init__(config)
self.config = config
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
self._supports_gradient_checkpointing = True
self.gradient_checkpointing = False
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
return_dict=None,
output_attentions=None,
output_hidden_states=None,
past_key_values=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
inputs_embeds = self.embedding(input_ids)
elif inputs_embeds is not None:
pass
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
import torch.utils.checkpoint
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                checkpoint_output = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
                hidden_states = checkpoint_output[0] if isinstance(checkpoint_output, tuple) else checkpoint_output
            else:
                hidden_states = layer(hidden_states)[0]
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if return_dict:
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
)
else:
return (loss, logits)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1, :].unsqueeze(1)
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if "attention_mask" in kwargs:
model_inputs["attention_mask"] = kwargs["attention_mask"]
return model_inputs
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.gradient_checkpointing = True
self.config.gradient_checkpointing = True
print("Gradient checkpointing enabled on MeshModel.")
def gradient_checkpointing_disable(self):
self.gradient_checkpointing = False
self.config.gradient_checkpointing = False
print("Gradient checkpointing disabled on MeshModel.")
def _set_gradient_checkpointing(self, enable=True):
if enable:
self.gradient_checkpointing_enable()
else:
self.gradient_checkpointing_disable()
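# Register the custom architecture so AutoConfig / AutoModelForCausalLM can resolve model_type "mesh".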
from transformers import AutoConfig
AutoConfig.register("mesh", MeshConfig)
AutoModelForCausalLM.register(MeshConfig, MeshModel)
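# Load the merged Stage 003 checkpoint and tokenizer from the Hugging Face Hub.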
HF_MERGED_REPO_STAGE003 = "mesh-labs/v0.1-2x2-stage003"
loaded_model_stage003 = None
loaded_tokenizer_stage003 = None
try:
print(f"Attempting to load Stage 003 merged model from HF: {HF_MERGED_REPO_STAGE003}...")
device_map = "auto"
loaded_model_stage003 = AutoModelForCausalLM.from_pretrained(
HF_MERGED_REPO_STAGE003,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch.float32
)
    # device_map="auto" already places the model on the best available device,
    # so an explicit .to('cuda') is unnecessary and can conflict with accelerate's dispatch.
    print(f"Stage 003 merged model loaded on: {loaded_model_stage003.device}")
loaded_tokenizer_stage003 = AutoTokenizer.from_pretrained(
HF_MERGED_REPO_STAGE003,
trust_remote_code=True,
use_fast=False
)
print("Stage 003 merged model and tokenizer loaded successfully from Hugging Face Hub.")
except Exception as e:
print(f"Error loading Stage 003 merged model or tokenizer from Hugging Face Hub: {e}")
loaded_model_stage003 = None
loaded_tokenizer_stage003 = None
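# Minimal interactive chat loop: greedy decoding with a "Question: ... Answer:" prompt template.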
if loaded_model_stage003 is not None and loaded_tokenizer_stage003 is not None:
print("\n--- Starting Chat Interface ---")
print("Type your message and press Enter. Type 'quit' to exit.")
loaded_model_stage003.eval()
while True:
try:
user_input = input("You: ")
if user_input.lower() == 'quit':
break
prompt = f"Question: {user_input}\nAnswer:"
inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to('cuda') for k, v in inputs.items()}
with torch.no_grad():
outputs = loaded_model_stage003.generate(
**inputs,
max_new_tokens=128,
num_beams=1,
do_sample=False,
)
generated_sequence = loaded_tokenizer_stage003.decode(outputs[0], skip_special_tokens=True)
answer_prefix = "Answer:"
answer_start_index = generated_sequence.find(answer_prefix)
if answer_start_index != -1:
generated_answer = generated_sequence[answer_start_index + len(answer_prefix):].strip()
else:
print("Warning: 'Answer:' prefix not found in generated text. Showing full generated sequence.")
generated_answer = generated_sequence.strip()
print("Model:", generated_answer)
except Exception as e:
print(f"An error occurred: {e}")
print("Please try again or type 'quit' to exit.")
else:
print("\nModel or tokenizer not loaded. Cannot start chat interface.")
Disclaimer
This small language model is just a proof of concept, paving the way for the final release, which is likely to happen in Q4 2025 and will include more models plus better support from external libraries such as Transformers and Llama.cpp.