---
license: mit
tags:
- kernel
---

# triton-kernels

triton-kernels is a set of kernels that enable fast MoE on different architectures. The kernels support multiple precisions (e.g. bf16, mxfp4).

Original code: https://github.com/triton-lang/triton/tree/main/python/triton_kernels

The current version corresponds to commit 7d0efaa7231661299284a603512fce4fa255e62c.

Note that we can't update these kernels at will, since some upstream commits may depend on Triton main; unfortunately, we have to wait for a new Triton release. See the related issue: https://github.com/triton-lang/triton/issues/7818

## Quickstart

```bash
uv run https://huggingface.co/kernels-community/triton_kernels/raw/main/readme_example.py
```

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "torch",
#   "triton",
#   "numpy",
#   "kernels",
# ]
# ///
import torch

from kernels import get_kernel

torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Load the triton_kernels module via the kernels library
triton_kernels = get_kernel("kernels-community/triton_kernels")

# Access submodules directly from the loaded kernel
swiglu = triton_kernels.swiglu
routing = triton_kernels.routing

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# SwiGLU example
x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"SwiGLU: {x.shape} -> {y.shape}")

# Routing example: pick the top 2 experts per token from 8 experts
logits = torch.randn(128, 8, device=device, dtype=torch.float16)
routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")

# MoE integrated: apply SwiGLU to the routed activations
n_tokens = routing_data.expt_hist.sum().item()
x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")
```
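
Since this repo tracks a specific upstream commit, you may want to pin the kernel revision you load for reproducibility. A minimal sketch, assuming the `revision` parameter of `kernels.get_kernel` (a branch, tag, or commit of this Hub repo); check your installed `kernels` version if the argument differs.

```python
from kernels import get_kernel

# Assumption: pin the download to a specific revision of this Hub repo
# so the loaded kernels don't change underneath you.
triton_kernels = get_kernel("kernels-community/triton_kernels", revision="main")
```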