```
oooo    ooo  .oooo.   ooo. .oo.  .oo.    .ooooo.   .ooooo.
 `88.  .8'  `P  )88b  `888P"Y88bP"Y88b  d88' `88b d88' `88b
  `88..8'    .oP"888   888   888   888  888   888 888ooo888
   `888'    d8(  888   888   888   888  888   888 888    .o
    .8'     `Y888""8o o888o o888o o888o `Y8bod8P' `Y8bod8P'
.o..P'
`Y8P'
```
Yet Another Mixture of Experts

yamoe is a no-nonsense, straightforward implementation of Mixture of Experts (MoE) kernels, designed to be very easy to use and computationally efficient.
Design goals
- simplicity: easy to read and understand the code
- efficiency: optimized for high throughput and low latency
- low memory usage: optimized to handle large batch sizes
- reproducibility: easy to reproduce results, no special new sm (GPU compute capability) requirements
- availability: easy to install and use via the kernels library
Kernel Hub
You can find the kernel on Kernel Hub and install it via the kernels library.
```python
from kernels import get_kernel

yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
```
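If you are iterating on the kernel locally, the kernels library also provides get_local_kernel for loading a kernel from a local checkout instead of the Hub. A minimal sketch; the path below is a placeholder, and the (repo_path, package_name) calling convention is an assumption that may differ between kernels versions:

```python
from pathlib import Path

from kernels import get_local_kernel

# Load the kernel from a local clone instead of downloading it from the Hub.
# Replace the path with wherever the yamoe repository is checked out.
yamoe = get_local_kernel(Path("/path/to/yamoe"), "yamoe")
```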
Performance
yamoe scales well as batch size increases, compared to the naive method of repeating the data and computation for every item in the batch, as implemented in the reference at torch-ext/yamoe/reference.py. The benchmark can be reproduced by running `uv run perf_plot.py`; a smaller benchmark with a correctness comparison can be run with `uv run compare_example.py`.

TL;DR: smaller is better in the first two rows of charts.
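For context, the naive baseline is the usual dense loop over experts: every token is projected through every expert and the routing weights zero out the unused results, so compute and activation memory grow with the full batch. The sketch below illustrates that pattern using the same tensor shapes as the example in the next section; it is not the actual torch-ext/yamoe/reference.py, and the SwiGLU-style gate/up split is an assumption.

```python
import torch
import torch.nn.functional as F

def naive_experts(hidden_states, routing_weights, gate_up_proj, gate_up_proj_bias,
                  down_proj, down_proj_bias):
    """Dense per-expert loop: every token is pushed through every expert."""
    # hidden_states: (tokens, hidden_dim), routing_weights: (tokens, num_experts)
    num_experts = routing_weights.shape[-1]
    out = torch.zeros_like(hidden_states)
    for e in range(num_experts):
        # Full dense projections for all tokens, even those not routed to expert e
        gate_up = hidden_states @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = gate_up.chunk(2, dim=-1)  # gate/up split convention is an assumption
        expert_out = (F.silu(gate) * up) @ down_proj[e] + down_proj_bias[e]
        # Routing weights are zero for tokens not assigned to this expert,
        # so most of the work computed above is simply discarded here
        out += routing_weights[:, e:e + 1] * expert_out
    return out
```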
How to use
```python
# /// script
# requires-python = "==3.10"
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
# [tool.uv.sources]
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
import time

import torch
from torch.nn import functional as F
from kernels import get_kernel

# Set seeds and deterministic flags for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

# Configuration
batch_size, seq_len, hidden_dim = 16, 256, 2880
num_experts, top_k = 8, 2

# Create routing weights: softmax over expert logits, keep the top-k per token,
# then scatter those weights into a dense (tokens, num_experts) matrix
logits = torch.randn(batch_size, seq_len, num_experts)
probs = F.softmax(logits, dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)
batch_seq = batch_size * seq_len
routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
routing_weights[batch_indices, flat_indices] = flat_weights

# Create model tensors
hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
routing_weights = routing_weights.cuda()
router_indices = flat_indices.cuda()

# Warmup
for _ in range(5):
    _ = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )

# Benchmark
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()

with torch.no_grad():
    output = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )

torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
print(f"First 3: {output.view(-1)[:3].tolist()}")
print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
```