Mesh-v0.1-2x2 (Stage 003)


Introducing mesh

This is our first-ever model! Allow us to explain how the Mesh architecture works in detail.

  • Neural Mesh extends the concept of Mixture of Experts by allowing bidirectional expert communication.

  • The experts are arranged in a two-dimensional grid (2x2, 4x4, etc.), which allows them to communicate with their neighbors via the "Neighbor Exchange" method.

  • Just like MoE models, Mesh models support dynamic routing, and the routing_k parameter defines the number of active parameters (see the sketch after this list). For this model (2x2):

    • top-1 routing: 173M active parameters
    • top-2 routing: 242M active parameters (default)
    • dense routing: 302M active parameters
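
As a quick illustration, the routing mode is just a configuration field. A minimal sketch, assuming the MeshConfig class defined in the loading code below (the active-parameter counts above come from this card and are not recomputed here):

# Hypothetical example: selecting the routing mode on a 2x2 mesh.
# routing_k=1 -> top-1 (~173M active), routing_k=2 -> top-2 (default, ~242M),
# routing_k=4 -> dense routing (~302M active).
config = MeshConfig(mesh_grid_size=(2, 2), routing_k=1)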

Here's how the mesh architecture works:

[Architecture diagram]
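
In pseudocode, each layer pushes every token through the pipeline below. This is a simplified sketch of the MeshLayer.forward method defined later on this page; scale and combine are shorthand here, not functions that exist in the actual code:

def mesh_layer(hidden_states):
    weights, indices = router(hidden_states)               # top-k expert selection
    outputs = [e(scale(hidden_states)) for e in experts]   # every expert processes the token
    outputs = neighbor_exchange(outputs)                   # mix each expert with its grid neighbors
    outputs = cross_expert_attention(outputs)              # experts attend to one another
    return combine(outputs, weights, indices)              # weighted sum of the top-k outputs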

How to load the model

Because Mesh is a custom architecture, the modeling code must be available before from_pretrained is called; the script below defines it inline and registers it with the Auto classes.

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
import torch
import torch.nn as nn
import torch.utils.checkpoint
import math

class MeshConfig(PretrainedConfig):
    model_type = "mesh"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=12,
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        mesh_grid_size=(2, 2),
        expert_intermediate_size=256,
        routing_k=2,
        neighbor_exchange_enabled=True,
        cross_expert_attention_enabled=True,
        expert_scale_factor="sqrt_k",
        load_in_8bit=False,
        load_in_4bit=False,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mesh_grid_size = mesh_grid_size
        # Use the explicit argument directly; popping it from kwargs would never
        # see the named parameter and would silently override the value passed in.
        self.expert_intermediate_size = expert_intermediate_size
        self.routing_k = routing_k
        self.neighbor_exchange_enabled = neighbor_exchange_enabled
        self.cross_expert_attention_enabled = cross_expert_attention_enabled
        self.expert_scale_factor = expert_scale_factor
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit

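# Each expert is a small two-layer GELU MLP (hidden_size -> expert_intermediate_size -> hidden_size).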
class MeshExpert(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

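# The router applies a linear gate over all grid experts, softmaxes the scores,
# and keeps the top routing_k experts per token.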
class MeshRouter(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        gate_scores = self.gate(x)
        gate_weights = self.softmax(gate_scores)
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
        return topk_weights, topk_indices

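# Neighbor Exchange: each expert's output is augmented (residually) with a learned
# projection of the mean output of its up-to-four adjacent experts in the grid.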
class NeighborExchange(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.num_experts_x = config.mesh_grid_size[0]
        self.num_experts_y = config.mesh_grid_size[1]
        self.num_experts = self.num_experts_x * self.num_experts_y

        self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, expert_outputs, expert_indices=None):
        if not self.config.neighbor_exchange_enabled:
            return expert_outputs

        batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
        reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
        aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)

        for i in range(self.num_experts_x):
            for j in range(self.num_experts_y):
                current_expert_output = reshaped_outputs[:, :, i, j, :]
                neighbor_info = torch.zeros_like(current_expert_output)
                neighbors = []
                if i > 0: neighbors.append(reshaped_outputs[:, :, i-1, j, :])
                if i < self.num_experts_x - 1: neighbors.append(reshaped_outputs[:, :, i+1, j, :])
                if j > 0: neighbors.append(reshaped_outputs[:, :, i, j-1, :])
                if j < self.num_experts_y - 1: neighbors.append(reshaped_outputs[:, :, i, j+1, :])

                if neighbors:
                    neighbor_stack = torch.stack(neighbors, dim=-2)
                    aggregated_info = torch.mean(neighbor_stack, dim=-2)
                    neighbor_info = aggregated_info

                transformed_neighbor_info = self.exchange_projection(neighbor_info)
                aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info

        aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
        exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info

        return exchanged_expert_outputs

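# Cross-expert attention: at every token position, the grid's expert outputs
# attend to one another via multi-head attention.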
class CrossExpertAttention(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            batch_first=True
        )

    def forward(self, expert_outputs):
        if not self.config.cross_expert_attention_enabled:
            return expert_outputs

        batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
        reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
        cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
        cross_attn_output = cross_attn_output.view(
            expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
        )
        return cross_attn_output

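# One mesh layer: route, run all experts, exchange with grid neighbors,
# cross-attend, then combine the top-k expert outputs using the router weights.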
class MeshLayer(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        topk_weights, topk_indices = self.router(hidden_states)
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)

        if self.config.expert_scale_factor == "sqrt_k":
            scaling_factor = math.sqrt(self.config.routing_k)
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        elif self.config.expert_scale_factor == "1_over_k":
            scaling_factor = 1.0 / self.config.routing_k
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        else:
            scaled_expert_inputs = expanded_hidden_states

        expert_outputs_list = [expert(scaled_expert_inputs[:, :, i, :]) for i, expert in enumerate(self.experts)]
        expert_outputs = torch.stack(expert_outputs_list, dim=2)

        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
        )

        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)

        return combined_output, topk_indices

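# The causal LM wrapper: token embeddings -> stacked mesh layers -> final norm -> LM head.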
class MeshModel(PreTrainedModel, GenerationMixin):
    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

        self._supports_gradient_checkpointing = True
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
        output_attentions=None,
        output_hidden_states=None,
        past_key_values=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is not None:
            pass
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute layer activations during backward to save memory.
                layer_output = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
            else:
                layer_output = layer(hidden_states)
            # MeshLayer returns (combined_output, topk_indices); keep the hidden states.
            hidden_states = layer_output[0] if isinstance(layer_output, tuple) else layer_output

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
            )
        # When return_dict is False, follow the HF convention of omitting a None loss.
        return (loss, logits) if loss is not None else (logits,)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1, :].unsqueeze(1)

        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if "attention_mask" in kwargs:
            model_inputs["attention_mask"] = kwargs["attention_mask"]

        return model_inputs

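    # Minimal gradient-checkpointing hooks so tools like the HF Trainer can toggle it.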
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.gradient_checkpointing = True
        self.config.gradient_checkpointing = True
        print("Gradient checkpointing enabled on MeshModel.")

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False
        self.config.gradient_checkpointing = False
        print("Gradient checkpointing disabled on MeshModel.")

    def _set_gradient_checkpointing(self, enable=True):
        if enable:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

# Register the custom architecture so the Auto classes can resolve model_type "mesh".
AutoConfig.register("mesh", MeshConfig)
AutoModelForCausalLM.register(MeshConfig, MeshModel)

HF_MERGED_REPO_STAGE003 = "mesh-labs/v0.1-2x2-stage003"

loaded_model_stage003 = None
loaded_tokenizer_stage003 = None

try:
    print(f"Attempting to load Stage 003 merged model from HF: {HF_MERGED_REPO_STAGE003}...")

    # device_map="auto" lets accelerate place the weights on a GPU when one is
    # available; calling .to('cuda') afterwards on a dispatched model is
    # unnecessary and can raise an error.
    loaded_model_stage003 = AutoModelForCausalLM.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float32,
    )
    print(f"Stage 003 merged model loaded on: {loaded_model_stage003.device}")

    loaded_tokenizer_stage003 = AutoTokenizer.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        use_fast=False
    )

    print("Stage 003 merged model and tokenizer loaded successfully from Hugging Face Hub.")

except Exception as e:
    print(f"Error loading Stage 003 merged model or tokenizer from Hugging Face Hub: {e}")
    loaded_model_stage003 = None
    loaded_tokenizer_stage003 = None

if loaded_model_stage003 is not None and loaded_tokenizer_stage003 is not None:
    print("\n--- Starting Chat Interface ---")
    print("Type your message and press Enter. Type 'quit' to exit.")

    loaded_model_stage003.eval()

    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() == 'quit':
                break

            prompt = f"Question: {user_input}\nAnswer:"

            inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt")

            # Move inputs to the same device as the model weights.
            inputs = {k: v.to(loaded_model_stage003.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = loaded_model_stage003.generate(
                    **inputs,
                    max_new_tokens=128,
                    num_beams=1,
                    do_sample=False,
                )

            generated_sequence = loaded_tokenizer_stage003.decode(outputs[0], skip_special_tokens=True)

            answer_prefix = "Answer:"
            answer_start_index = generated_sequence.find(answer_prefix)

            if answer_start_index != -1:
                generated_answer = generated_sequence[answer_start_index + len(answer_prefix):].strip()
            else:
                print("Warning: 'Answer:' prefix not found in generated text. Showing full generated sequence.")
                generated_answer = generated_sequence.strip()

            print("Model:", generated_answer)

        except Exception as e:
            print(f"An error occurred: {e}")
            print("Please try again or type 'quit' to exit.")

else:
    print("\nModel or tokenizer not loaded. Cannot start chat interface.")

Disclaimer

This small language model is just a proof of concept, paving the way for the final release, which is likely to happen in Q4 2025 and will include more models and better support from external libraries such as Transformers and Llama.cpp.

Model size: 420M parameters (F32, Safetensors).