metadata
license: apache-2.0
datasets:
  - HuggingFaceFW/fineweb-edu
  - HuggingFaceH4/MATH-500
  - openai/gsm8k
language:
  - en
pipeline_tag: text-generation
tags:
  - mesh
  - moe
  - mesh-labs
  - alpha
  - preview
  - research
  - experiment
  - routing
  - innovative
  - innovation
  - mesh-moe
  - custom_code

Mesh-v0.1-2x2 (Stage 003)

Introducing mesh

This is our first ever model! Allow us to explain how the mesh architecture works in detail.

  • Neural Mesh extends the concept of Mixture of Experts by allowing bidirectional expert communication.

  • The experts are arranged in a two-dimensional grid (2x2, 4x4, etc.), which lets each expert communicate with its neighbors through the "Neighbor Exchange" method.

  • Just like MoE models, Mesh models use dynamic routing, and the routing_k parameter sets how many experts are active per token and therefore the number of active parameters. For this model (2x2):

    • top-1 routing: 173M active parameters
    • top-2 routing: 242M active parameters (default)
    • dense routing: 302M active parameters
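
The routing step described above can be illustrated with a minimal, pure-Python sketch. It mirrors the softmax-then-top-k logic of the MeshRouter class shown later in this README, but it is only an illustration: the function names and the gate scores are hypothetical, and the real model operates on tensors rather than lists.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(gate_scores, routing_k):
    """Return (weights, indices) of the routing_k most probable experts.

    As in MeshRouter, the selected weights are not renormalized,
    so they sum to less than 1 unless every expert is selected.
    """
    probs = softmax(gate_scores)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen = ranked[:routing_k]
    return [probs[i] for i in chosen], chosen

# Hypothetical gate scores for one token over the four experts of a 2x2 grid.
scores = [2.0, 0.5, 1.0, -1.0]
weights, indices = route_top_k(scores, routing_k=2)
print(indices)  # [0, 2]
```

With routing_k equal to the number of experts (4 on a 2x2 grid), every expert is selected, which presumably corresponds to the dense routing setting listed above.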

Here's how the mesh architecture works:

[Figure: diagram of the mesh architecture]
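
The grid layout and the "Neighbor Exchange" idea can also be sketched in plain Python. This is a simplified illustration under stated assumptions, not the model's implementation: each expert output is a single float instead of a hidden-state vector, the learned projection applied in the NeighborExchange class is omitted, and the function names grid_neighbors and exchange are hypothetical.

```python
def grid_neighbors(rows, cols):
    """Map each expert's flat (row-major) index to its up/down/left/right neighbors."""
    neighbors = {}
    for i in range(rows):
        for j in range(cols):
            adjacent = []
            if i > 0:
                adjacent.append((i - 1) * cols + j)   # expert above
            if i < rows - 1:
                adjacent.append((i + 1) * cols + j)   # expert below
            if j > 0:
                adjacent.append(i * cols + (j - 1))   # expert to the left
            if j < cols - 1:
                adjacent.append(i * cols + (j + 1))   # expert to the right
            neighbors[i * cols + j] = adjacent
    return neighbors

def exchange(outputs, rows, cols):
    """Add the mean of the neighbors' outputs to each expert's own output."""
    table = grid_neighbors(rows, cols)
    mixed = []
    for idx, value in enumerate(outputs):
        adjacent = table[idx]
        mean = sum(outputs[n] for n in adjacent) / len(adjacent) if adjacent else 0.0
        mixed.append(value + mean)
    return mixed

# Scalar stand-ins for the outputs of the four experts in a 2x2 grid.
print(exchange([1.0, 2.0, 3.0, 4.0], rows=2, cols=2))  # [3.5, 4.5, 5.5, 6.5]
```

On a 2x2 grid every expert has exactly two neighbors; larger grids give interior experts four neighbors and edge experts three.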

How to load the model

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
import os

class MeshConfig(PretrainedConfig):
    model_type = "mesh"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_key_value_heads=12,
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        mesh_grid_size=(2, 2),
        expert_intermediate_size=256,
        routing_k=2,
        neighbor_exchange_enabled=True,
        cross_expert_attention_enabled=True,
        expert_scale_factor="sqrt_k",
        load_in_8bit=False,
        load_in_4bit=False,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mesh_grid_size = mesh_grid_size
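        # Note: expert_intermediate_size is a named argument, so it never appears in kwargs.
        # The pop below therefore always falls back to intermediate_size // number of experts.
        self.expert_intermediate_size = kwargs.pop("expert_intermediate_size", intermediate_size // (mesh_grid_size[0] * mesh_grid_size[1]))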
        self.routing_k = routing_k
        self.neighbor_exchange_enabled = neighbor_exchange_enabled
        self.cross_expert_attention_enabled = cross_expert_attention_enabled
        self.expert_scale_factor = expert_scale_factor
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit

class MeshExpert(nn.Module):
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

class MeshRouter(nn.Module):
    """Scores every expert per token and returns the top routing_k softmax weights and indices."""
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
        self.softmax = nn.Softmax(dim=-1)
        self.routing_k = config.routing_k

    def forward(self, x):
        gate_scores = self.gate(x)
        gate_weights = self.softmax(gate_scores)
        topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
        return topk_weights, topk_indices

class NeighborExchange(nn.Module):
    """Adds to each expert's output a linear projection of the mean of its grid neighbors' outputs."""
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.num_experts_x = config.mesh_grid_size[0]
        self.num_experts_y = config.mesh_grid_size[1]
        self.num_experts = self.num_experts_x * self.num_experts_y

        self.exchange_projection = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, expert_outputs, expert_indices=None):
        if not self.config.neighbor_exchange_enabled:
            return expert_outputs

        batch_size, seq_length, num_experts, hidden_size = expert_outputs.shape
        reshaped_outputs = expert_outputs.view(batch_size, seq_length, self.num_experts_x, self.num_experts_y, hidden_size)
        aggregated_neighbor_info = torch.zeros_like(reshaped_outputs)

        for i in range(self.num_experts_x):
            for j in range(self.num_experts_y):
                current_expert_output = reshaped_outputs[:, :, i, j, :]
                neighbor_info = torch.zeros_like(current_expert_output)
                neighbors = []
                if i > 0: neighbors.append(reshaped_outputs[:, :, i-1, j, :])
                if i < self.num_experts_x - 1: neighbors.append(reshaped_outputs[:, :, i+1, j, :])
                if j > 0: neighbors.append(reshaped_outputs[:, :, i, j-1, :])
                if j < self.num_experts_y - 1: neighbors.append(reshaped_outputs[:, :, i, j+1, :])

                if neighbors:
                    neighbor_stack = torch.stack(neighbors, dim=-2)
                    aggregated_info = torch.mean(neighbor_stack, dim=-2)
                    neighbor_info = aggregated_info

                transformed_neighbor_info = self.exchange_projection(neighbor_info)
                aggregated_neighbor_info[:, :, i, j, :] = transformed_neighbor_info

        aggregated_neighbor_info = aggregated_neighbor_info.view(batch_size, seq_length, num_experts, hidden_size)
        exchanged_expert_outputs = expert_outputs + aggregated_neighbor_info

        return exchanged_expert_outputs

class CrossExpertAttention(nn.Module):
    """Applies multi-head self-attention across the expert axis, independently for each token."""
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            batch_first=True
        )

    def forward(self, expert_outputs):
        if not self.config.cross_expert_attention_enabled:
            return expert_outputs

        batch_seq_size = expert_outputs.shape[0] * expert_outputs.shape[1]
        reshaped_outputs = expert_outputs.view(batch_seq_size, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size)
        cross_attn_output, _ = self.cross_attention(reshaped_outputs, reshaped_outputs, reshaped_outputs)
        cross_attn_output = cross_attn_output.view(
            expert_outputs.shape[0], expert_outputs.shape[1], self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], self.config.hidden_size
        )
        return cross_attn_output

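class MeshLayer(nn.Module):
    """Runs all experts on every token, mixes their outputs, then combines the top-k routed experts."""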
    def __init__(self, config: MeshConfig):
        super().__init__()
        self.config = config
        self.router = MeshRouter(config)
        self.experts = nn.ModuleList([MeshExpert(config) for _ in range(config.mesh_grid_size[0] * config.mesh_grid_size[1])])
        self.neighbor_exchange = NeighborExchange(config)
        self.cross_expert_attention = CrossExpertAttention(config)

    def forward(self, hidden_states):
        topk_weights, topk_indices = self.router(hidden_states)
        expanded_hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.config.mesh_grid_size[0] * self.config.mesh_grid_size[1], -1)

        if self.config.expert_scale_factor == "sqrt_k":
            scaling_factor = math.sqrt(self.config.routing_k)
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        elif self.config.expert_scale_factor == "1_over_k":
            scaling_factor = 1.0 / self.config.routing_k
            scaled_expert_inputs = expanded_hidden_states * scaling_factor
        else:
            scaled_expert_inputs = expanded_hidden_states

        expert_outputs_list = [expert(scaled_expert_inputs[:, :, i, :]) for i, expert in enumerate(self.experts)]
        expert_outputs = torch.stack(expert_outputs_list, dim=2)

        exchanged_expert_outputs = self.neighbor_exchange(expert_outputs, topk_indices)
        cross_attned_expert_outputs = self.cross_expert_attention(exchanged_expert_outputs)

        gathered_outputs = torch.gather(
            cross_attned_expert_outputs,
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, self.config.hidden_size)
        )

        combined_output = (gathered_outputs * topk_weights.unsqueeze(-1)).sum(dim=2)

        return combined_output, topk_indices

class MeshModel(PreTrainedModel, GenerationMixin):
    config_class = MeshConfig

    def __init__(self, config: MeshConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MeshLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

        self._supports_gradient_checkpointing = True
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
        output_attentions=None,
        output_hidden_states=None,
        past_key_values=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            inputs_embeds = self.embedding(input_ids)
        elif inputs_embeds is not None:
            pass
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            import torch.utils.checkpoint

        for i, layer in enumerate(self.layers):
            if hasattr(layer, 'forward') and callable(layer.forward):
                if self.gradient_checkpointing and self.training:
                    checkpoint_output = torch.utils.checkpoint.checkpoint(
                        layer, hidden_states, use_reentrant=False
                    )
                    if isinstance(checkpoint_output, tuple):
                        hidden_states = checkpoint_output[0]
                    else:
                        hidden_states = checkpoint_output
                else:
                    layer_output = layer(hidden_states)
                    hidden_states = layer_output[0]
            else:
                print(f"Warning: Layer {i} does not have a callable forward method. Skipping layer processing.")

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
            )
        else:
            return (loss, logits)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1, :].unsqueeze(1)

        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if "attention_mask" in kwargs:
            model_inputs["attention_mask"] = kwargs["attention_mask"]

        return model_inputs

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.gradient_checkpointing = True
        self.config.gradient_checkpointing = True
        print("Gradient checkpointing enabled on MeshModel.")

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False
        self.config.gradient_checkpointing = False
        print("Gradient checkpointing disabled on MeshModel.")

    def _set_gradient_checkpointing(self, enable=True):
        if enable:
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

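# Register the custom classes so the Auto* factories can resolve model_type "mesh".
# (AutoConfig is already imported at the top of the script.)
AutoConfig.register("mesh", MeshConfig)
AutoModelForCausalLM.register(MeshConfig, MeshModel)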

HF_MERGED_REPO_STAGE003 = "mesh-labs/v0.1-2x2-stage003"

loaded_model_stage003 = None
loaded_tokenizer_stage003 = None

try:
    print(f"Attempting to load Stage 003 merged model from HF: {HF_MERGED_REPO_STAGE003}...")
    device_map = "auto"

    loaded_model_stage003 = AutoModelForCausalLM.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch.float32
    )

    if torch.cuda.is_available():
        loaded_model_stage003.to('cuda')
        print("Stage 003 merged model moved to GPU.")
    else:
        print("Stage 003 merged model loaded on CPU.")

    loaded_tokenizer_stage003 = AutoTokenizer.from_pretrained(
        HF_MERGED_REPO_STAGE003,
        trust_remote_code=True,
        use_fast=False
    )

    print("Stage 003 merged model and tokenizer loaded successfully from Hugging Face Hub.")

except Exception as e:
    print(f"Error loading Stage 003 merged model or tokenizer from Hugging Face Hub: {e}")
    loaded_model_stage003 = None
    loaded_tokenizer_stage003 = None

if loaded_model_stage003 is not None and loaded_tokenizer_stage003 is not None:
    print("\n--- Starting Chat Interface ---")
    print("Type your message and press Enter. Type 'quit' to exit.")

    loaded_model_stage003.eval()

    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() == 'quit':
                break

            prompt = f"Question: {user_input}\nAnswer:"

            inputs = loaded_tokenizer_stage003(prompt, return_tensors="pt")

            if torch.cuda.is_available():
                inputs = {k: v.to('cuda') for k, v in inputs.items()}

            with torch.no_grad():
                outputs = loaded_model_stage003.generate(
                    **inputs,
                    max_new_tokens=128,
                    num_beams=1,
                    do_sample=False,
                )

            generated_sequence = loaded_tokenizer_stage003.decode(outputs[0], skip_special_tokens=True)

            answer_prefix = "Answer:"
            answer_start_index = generated_sequence.find(answer_prefix)

            if answer_start_index != -1:
                generated_answer = generated_sequence[answer_start_index + len(answer_prefix):].strip()
            else:
                print("Warning: 'Answer:' prefix not found in generated text. Showing full generated sequence.")
                generated_answer = generated_sequence.strip()

            print("Model:", generated_answer)

        except Exception as e:
            print(f"An error occurred: {e}")
            print("Please try again or type 'quit' to exit.")

else:
    print("\nModel or tokenizer not loaded. Cannot start chat interface.")
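
For non-interactive use, the prompt format and the answer extraction used in the chat loop above could be factored into two small helpers. This is a minimal sketch: the names build_prompt, extract_answer, and ANSWER_PREFIX are hypothetical and not part of the released code, and the helpers only handle strings, so they would need to be combined with the tokenizer and generate() calls shown above.

```python
ANSWER_PREFIX = "Answer:"

def build_prompt(question: str) -> str:
    """Wrap a user question in the Question/Answer format used by the chat loop."""
    return f"Question: {question}\n{ANSWER_PREFIX}"

def extract_answer(generated: str) -> str:
    """Return the text after the first 'Answer:' marker, or the whole text if it is absent."""
    start = generated.find(ANSWER_PREFIX)
    if start == -1:
        return generated.strip()
    return generated[start + len(ANSWER_PREFIX):].strip()

# Example with a made-up generation string (not real model output).
prompt = build_prompt("What is 2 + 2?")
print(extract_answer(prompt + " 4"))  # 4
```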

Disclaimer

This small language model is a proof of concept that paves the way for the final release, which is likely to happen in Q4 2025 and to include more models and better support from external libraries such as Transformers and llama.cpp.