--- license: apache-2.0 datasets: - HuggingFaceFW/fineweb-edu - HuggingFaceH4/MATH-500 - openai/gsm8k language: - en pipeline_tag: text-generation tags: - mesh - moe - mesh-labs - alpha - preview - research - experiment - routing - innovative - innovation - mesh-moe - custom_code --- # Mesh-v0.1-2x2 (Stage 003) ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6747320df82ae35f0327cdd3/2JPwH3coASgEc4vJvJVRt.png) ## Introducing mesh This is our first ever model! Allow us to explain how the `mesh` architecture works in detail. - Neural Mesh extends the concept of Mixture of Experts by allowing bidirectional expert communication. - The experts are shared in a bidimensional grid (2x2, 4x4, etc.) layout, that allows for them to communicate with their neighbors using the "Neighbor Exchange" method. - Just like MoE models, Mesh models have dynamic routing, and through the `routing_k` parameter you can define the amount of active parameters. For this model (2x2): - top-1 routing: 173M active parameters - top-2 routing: 242M active parameters (default) - dense routing: 302M active parameters ## Here's how the mesh architecture works: ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6747320df82ae35f0327cdd3/WRpS2T5KBMPbacobfh0bw.png) ## How to load the model ```python from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig, PreTrainedModel import torch import torch.nn as nn import torch.nn.functional as F import math from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation import GenerationMixin import os class MeshConfig(PretrainedConfig): model_type = "mesh" def __init__( self, vocab_size=32000, hidden_size=768, intermediate_size=2048, num_hidden_layers=12, num_attention_heads=12, num_key_value_heads=12, max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, mesh_grid_size=(2, 2), expert_intermediate_size=256, routing_k=2, neighbor_exchange_enabled=True, cross_expert_attention_enabled=True, expert_scale_factor="sqrt_k", load_in_8bit=False, load_in_4bit=False, **kwargs ): super().__init__( vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, max_position_embeddings=max_position_embeddings, initializer_range=initializer_range, rms_norm_eps=rms_norm_eps, use_cache=use_cache, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.mesh_grid_size = mesh_grid_size self.expert_intermediate_size = kwargs.pop("expert_intermediate_size", intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1])) self.routing_k = routing_k self.neighbor_exchange_enabled = neighbor_exchange_enabled self.cross_expert_attention_enabled = cross_expert_attention_enabled self.expert_scale_factor = expert_scale_factor self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit class MeshExpert(nn.Module): def __init__(self, config: MeshConfig): super().__init__() self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size) self.gelu = nn.GELU() self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size) def forward(self, x): return self.fc2(self.gelu(self.fc1(x))) class MeshRouter(nn.Module): def __init__(self, config: MeshConfig): super().__init__() self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1]) self.softmax = nn.Softmax(dim=-1) self.routing_k = config.routing_k def forward(self, x): gate_scores = self.gate(x) gate_weights = self.softmax(gate_scores) topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1) return topk_weights, topk_indices class NeighborExchange(nn.Module): def __init__(self, config: MeshConfig): super().__init__() self.config = config self.num_experts_x = config.mesh_grid_size[0] self.num_experts_y = config.mesh_grid_size[1] self.num_experts = self.num_experts_x * self.num_experts_y self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size) def forward(self, expert_outputs, expert_indices=None): if not self.config.neighbor_exchange_enabled: return expert_outputs batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size) aggregated_neighbor_info = torch.zeros_like(reshaped_outputs) for i in range(self.num_experts_x): for j in range(self.num_experts_y): current_expert_output = reshaped_outputs[:, :, i, j, :] neighbor_info = torch.zeros_like(current_expert_output) neighbors = [] if i > 0: neighbors.append(reshaped_outputs[:, :, i-1, j, :]) if i < self.num_experts_x - 1: neighbors.append(reshaped_outputs[:, :, i+1, j, :]) if j > 0: neighbors.append(reshaped_outputs[:, :, i, j-1, :]) if j < self.num_experts_y - 1: neighbors.append(reshaped_outputs[:, :, i, j+1, :]) if neighbors: neighbor_stack = torch.stack(neighbors, dim=-2) aggregated_info = torch.mean(neighbor_stack, dim=-2) neighbor_info = aggregated_info transformed_neighbor_info = self.exchange_projection(neighbor_info) aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size) exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info return exchanged_expert_outputs class CrossExpertAttention(nn.Module): def __init__(self, config: MeshConfig): super().__init__() self.config = config self.cross_attention = nn.MultiheadAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, batch_first=True ) def forward(self, expert_outputs): if not self.config.cross_expert_attention_enabled: return expert_outputs batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1] reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size) cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs) cross_attn_output = cross_attn_output.view( expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size ) return cross_attn_output class MeshLayer(nn.Module): def __init__(self, config: MeshConfig): super().__init__() self.config = config self.router = MeshRouter(config) self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])]) self.neighbor_exchange = NeighborExchange(config) self.cross_expert_attention = CrossExpertAttention(config) def forward(self, hidden_states): topk_weights, topk_indices = self.router(hidden_states) expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1) if self.config.expert_scale_factor == "sqrt_k": scaling_factor = math.sqrt(self.config.routing_k) scaled_expert_inputs = expanded_hidden_states * scaling_factor elif self.config.expert_scale_factor == "1_over_k": scaling_factor = 1.0 / self.config.routing_k scaled_expert_inputs = expanded_hidden_states * scaling_factor else: scaled_expert_inputs = expanded_hidden_states expert_outputs_list = [expert(scaled_expert_inputs[:, :, i, :]) for i, expert in enumerate(self.experts)] expert_outputs = torch.stack(expert_outputs_list, dim=2) exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices) cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs) gathered_outputs = torch.gather( cross_attned_expert_outputs, dim=2, index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size) ) combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2) return combined_output, topk_indices class MeshModel(PreTrainedModel, GenerationMixin): config_class = MeshConfig def __init__(self, config: MeshConfig): super().__init__(config) self.config = config self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() self._supports_gradient_checkpointing = True self.gradient_checkpointing = False def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, labels=None, return_dict=None, output_attentions=None, output_hidden_states=None, past_key_values=None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: inputs_embeds = self.embedding(input_ids) elif inputs_embeds is not None: pass else: raise ValueError("You have to specify either input_ids or inputs_embeds") hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: import torch.utils.checkpoint for i, layer in enumerate(self.layers): if hasattr(layer, 'forward') and callable(layer.forward): if self.gradient_checkpointing and self.training: checkpoint_output = torch.utils.checkpoint.checkpoint( layer, hidden_states, use_reentrant=False ) if isinstance(checkpoint_output, tuple): hidden_states = checkpoint_output[0] else: hidden_states = checkpoint_output else: layer_output = layer(hidden_states) hidden_states = layer_output[0] else: print(f"Warning: Layer {i} does not have a callable forward method. Skipping layer processing.") hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) if return_dict: return CausalLMOutputWithPast( loss=loss, logits=logits, ) else: return (loss, logits) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) if inputs_embeds is not None: inputs_embeds = inputs_embeds[:, -1, :].unsqueeze(1) if inputs_embeds is not None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} if "attention_mask" in kwargs: model_inputs["attention_mask"] = kwargs["attention_mask"] return model_inputs def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): self.gradient_checkpointing = True self.config.gradient_checkpointing = True print("Gradient checkpointing enabled on MeshModel.") def gradient_checkpointing_disable(self): self.gradient_checkpointing = False self.config.gradient_checkpointing = False print("Gradient checkpointing disabled on MeshModel.") def _set_gradient_checkpointing(self, enable=True): if enable: self.gradient_checkpointing_enable() else: self.gradient_checkpointing_disable() from transformers import AutoConfig AutoConfig.register("mesh", MeshConfig) AutoModelForCausalLM.register(MeshConfig, MeshModel) HF_MERGED_REPO_STAGE003 = "mesh-labs/v0.1-2x2-stage003" loaded_model_stage003 = None loaded_tokenizer_stage003 = None try: print(f"Attempting to load Stage 003 merged model from HF: {HF_MERGED_REPO_STAGE003}...") device_map = "auto" loaded_model_stage003 = AutoModelForCausalLM.from_pretrained( HF_MERGED_REPO_STAGE003, trust_remote_code=True, device_map=device_map, torch_dtype=torch.float32 ) if torch.cuda.is_available(): loaded_model_stage003.to('cuda') print("Stage 003 merged model moved to GPU.") else: print("Stage 003 merged model loaded on CPU.") loaded_tokenizer_stage003 = AutoTokenizer.from_pretrained( HF_MERGED_REPO_STAGE003, trust_remote_code=True, use_fast=False ) print("Stage 003 merged model and tokenizer loaded successfully from Hugging Face Hub.") except Exception as e: print(f"Error loading Stage 003 merged model or tokenizer from Hugging Face Hub: {e}") loaded_model_stage003 = None loaded_tokenizer_stage003 = None if loaded_model_stage003 is not None and loaded_tokenizer_stage003 is not None: print("\n--- Starting Chat Interface ---") print("Type your message and press Enter. Type 'quit' to exit.") loaded_model_stage003.eval() while True: try: user_input = input("You: ") if user_input.lower() == 'quit': break prompt = f"Question: {user_input}\nAnswer:" inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to('cuda') for k, v in inputs.items()} with torch.no_grad(): outputs = loaded_model_stage003.generate( **inputs, max_new_tokens=128, num_beams=1, do_sample=False, ) generated_sequence = loaded_tokenizer_stage003.decode(outputs[0], skip_special_tokens=True) answer_prefix = "Answer:" answer_start_index = generated_sequence.find(answer_prefix) if answer_start_index != -1: generated_answer = generated_sequence[answer_start_index + len(answer_prefix):].strip() else: print("Warning: 'Answer:' prefix not found in generated text. Showing full generated sequence.") generated_answer = generated_sequence.strip() print("Model:", generated_answer) except Exception as e: print(f"An error occurred: {e}") print("Please try again or type 'quit' to exit.") else: print("\nModel or tokenizer not loaded. Cannot start chat interface.") ``` ## Disclaimer This small language model is just a proof-of-concept, paving the way to the final release, which is likely to happen in Q4 2025, and include more models and better support from external libraries such as Transformers and Llama.cpp.