# DeepSeek-Tiny with MLA-o V0.1
A 6-layer DeepSeek-V3 model with MLA plus a shared output latent space ("MLA-o"), trained for research on shared subspaces in Transformer attention mechanisms.
## Model Description
- Model Type: Transformer Decoder (DeepSeek-V3 based)
- Architecture: 6-layer decoder with Mixture of Experts
- Parameters: 16.17M
- Hidden Size: 256
- Attention Heads: 8
- Head Dimension: 32
- Sequence Length: 1,024 tokens
- Query Latent Dimension: 96
- Key-Value Latent Dimension: 64
- Output Latent Dimension: 96 (see the config sketch after this list)
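
To cross-check these numbers against the released `config.json`, here is a minimal sketch. Field names follow the Hugging Face `DeepseekV3Config`; the mapping from the bullets above to these fields is an assumption, not something stated explicitly on this card.

```python
from transformers import AutoConfig

# Read the released config and cross-check the bullets above.
config = AutoConfig.from_pretrained("ChrisMcCormick/deepseek-tiny-mla-o-v0.1")

print(config.num_hidden_layers)    # 6   - decoder layers
print(config.hidden_size)          # 256 - hidden size
print(config.num_attention_heads)  # 8   - attention heads
print(config.v_head_dim)           # 32  - head dimension
print(config.q_lora_rank)          # 96  - query latent dimension
print(config.kv_lora_rank)         # 64  - key-value latent dimension
```

The 96-dimensional output latent space has no standard `DeepseekV3Config` field; it is introduced by the `o_proj` patch shown in the Usage section below.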
## Performance
- SST-2 Accuracy: 86.24%
- WikiText-103 Perplexity: 29.33
## Research Context
This model is part of the shared-subspaces research project investigating the impact of shared output latent spaces in Transformer attention mechanisms.
## Output Subspace Decomposition
This model implements a shared output latent space: the attention output projection W^O is decomposed into

W^O = W^OA · W^OB

where W^OA contains the per-head projections into the shared output latent space and W^OB is a single shared projection from the latent space back to the model dimension. In the implementation, an RMSNorm is applied to the latent activations between the two projections (see the Usage code below).
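
For intuition, here is a minimal sketch of the decomposition at this model's sizes (8 heads × 32-dim values in, 96-dim latent, 256-dim model out), with an RMSNorm between the two factors as in the patch below. The parameter counts shown are per attention layer.

```python
import torch.nn as nn  # nn.RMSNorm requires PyTorch >= 2.4

d_model, n_heads, v_head_dim, o_latent = 256, 8, 32, 96  # this model's dimensions

# Dense output projection: W^O maps the concatenated head outputs back to d_model.
dense_o_proj = nn.Linear(n_heads * v_head_dim, d_model, bias=False)     # 256 -> 256

# Decomposed: W^OA projects into the shared output latent space, W^OB projects back.
decomposed_o_proj = nn.Sequential(
    nn.Linear(n_heads * v_head_dim, o_latent, bias=False),  # W^OA: 256 -> 96
    nn.RMSNorm(o_latent),                                    # norm in the latent space
    nn.Linear(o_latent, d_model, bias=False),                # W^OB: 96 -> 256
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense_o_proj), count(decomposed_o_proj))  # 65536 vs. 49248 parameters
```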
## Usage
Rather than overwrite the entire attention layer, we simply patched the `o_proj` module, replacing it with an `nn.Sequential`. This is an easy way to modify the model prior to pre-training, but loading the weights back is a different story: the decomposed weights in the checkpoint do not match the stock `o_proj` layer. The code below applies the patch and then loads the necessary weights manually.
```python
import torch
import torch.nn as nn
from transformers import DeepseekV3ForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def load_mla_o_model(repo_id="ChrisMcCormick/deepseek-tiny-mla-o-v0.1"):
    """
    Load the MLA-o model with output subspace decomposition.
    """
    print("\n<<Ignore the 'weights not used' warning>>\n")

    # Load the base model (without the decomposed o_proj weights).
    model = DeepseekV3ForCausalLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    print("\nPatching weights...\n")

    # Download the safetensors file to get the decomposed weights.
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    weights = load_file(weights_path)

    # Apply the output subspace decomposition to all attention layers.
    for layer_idx, layer in enumerate(model.model.layers):
        attn = layer.self_attn

        # Calculate dimensions.
        in_features = attn.num_heads * attn.v_head_dim   # 8 * 32 = 256
        o_latent_dim = 96                                # Output latent dimension
        out_features = model.config.hidden_size          # 256
        bias = bool(getattr(model.config, "attention_bias", False))

        # Replace o_proj with the sequential decomposition.
        attn.o_proj = nn.Sequential(
            nn.Linear(in_features, o_latent_dim, bias=False),          # W^OA: 256 -> 96
            nn.RMSNorm(o_latent_dim, eps=model.config.rms_norm_eps),   # Latent-space normalization
            nn.Linear(o_latent_dim, out_features, bias=bias),          # W^OB: 96 -> 256
        )

        # Load the decomposed weights for this layer. The .to(model.dtype) casts
        # are a safeguard in case the checkpoint tensors are stored in bfloat16.
        layer_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"

        # W^OA weights (o_proj.0.weight)
        w_oa_key = f"{layer_prefix}.0.weight"
        if w_oa_key in weights:
            attn.o_proj[0].weight.data = weights[w_oa_key].to(model.dtype)

        # RMSNorm weights (o_proj.1.weight)
        w_norm_key = f"{layer_prefix}.1.weight"
        if w_norm_key in weights:
            attn.o_proj[1].weight.data = weights[w_norm_key].to(model.dtype)

        # W^OB weights (o_proj.2.weight)
        w_ob_key = f"{layer_prefix}.2.weight"
        if w_ob_key in weights:
            attn.o_proj[2].weight.data = weights[w_ob_key].to(model.dtype)

        # W^OB bias, if it exists.
        w_ob_bias_key = f"{layer_prefix}.2.bias"
        if w_ob_bias_key in weights and attn.o_proj[2].bias is not None:
            attn.o_proj[2].bias.data = weights[w_ob_bias_key].to(model.dtype)

    print("Model loaded and patched.")

    return model, tokenizer


# Load the model.
model, tokenizer = load_mla_o_model()

# Generate text.
inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=50,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
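
As a quick sanity check after loading, the total parameter count of the patched model can be compared against the 16.17M figure above (a small check; the exact number may differ slightly depending on how tied weights are counted).

```python
# Total parameters of the patched model; this should come out close to 16.17M.
n_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {n_params / 1e6:.2f}M")
```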
## Training Details
- Pre-training Dataset: WikiText-103
- Optimizer: AdamW
- Learning Rate: 5e-4
- Weight Decay: 0.01
- Precision: bfloat16
- Compilation: torch.compile with inductor backend
- Training Steps: 12,500
- Effective Batch Size: 1,024 (one way these settings combine is sketched below)
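
As a rough sketch of how these settings fit together, reusing the `model` loaded in the Usage section: this is not the actual training script, and the learning-rate schedule, micro-batch size, gradient-accumulation split, and the way bfloat16 was applied are all assumptions.

```python
import torch

# Hypothetical reconstruction of the optimizer and compile setup from the bullets above.
model = torch.compile(model, backend="inductor")  # torch.compile with the inductor backend

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)

# An effective batch size of 1,024 could be reached with gradient accumulation,
# e.g. micro-batches of 32 accumulated over 32 steps (an assumed split).
micro_batch_size, grad_accum_steps = 32, 32
assert micro_batch_size * grad_accum_steps == 1024

# bfloat16 could be applied with autocast around each forward/backward pass:
#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#       loss = model(**batch).loss
```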
## Limitations
- Small-scale model (16M parameters) intended for research purposes
- Trained on limited data compared to production models