🚀 Exporting Mistral-7B to ONNX for Unity Sentis 6.2
Unity 6.2 just dropped with upgraded ONNX / Sentis runtime support. But until now, most devs were told:
“You can’t export Mistral to ONNX. PyTorch’s scaled_dot_product_attention and Hugging Face’s masks make it impossible.”
We decided to prove otherwise. This article shows how to patch MistralAttention, bypass fused SDPA, and export a clean .onnx that runs in ONNX Runtime and Unity Sentis 6.2.
⚡ The Problem
Typical export fails with:
UnsupportedOperatorError: Exporting the operator 'aten::__ior_' ...
TypeError: scaled_dot_product_attention ...
ValueError: too many values to unpack ...
RuntimeError: Only tuples, lists and Variables are supported ... (DynamicCache)
Three blockers (the naive export call that triggers them is sketched after this list):
Hugging Face's mask utilities use in-place boolean ops (__ior__, i.e. |=), which the exporter can't translate.
PyTorch 2.x's fused scaled_dot_product_attention can't be traced for ONNX export.
Mistral's grouped query attention (GQA) plus HF's DynamicCache breaks the export unless you strip the cache.
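For reference, this is roughly the naive call that produces those errors on an unpatched model (a sketch; it assumes the model is loaded as shown in the Environment section below):
import torch
# Naive export attempt: with an unpatched Mistral this is what typically
# hits the __ior__ / SDPA / DynamicCache blockers listed above
dummy_ids = torch.randint(0, model.config.vocab_size, (1, 128), dtype=torch.long).to("cuda")
torch.onnx.export(
    model,                 # unpatched MistralForCausalLM
    (dummy_ids,),
    "mistral_naive.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    opset_version=15,
)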
🛠️ Environment
Colab Pro+ A100, but any CUDA 12.x setup works.
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
pip install transformers==4.53.3 optimum==1.27.0 onnx==1.16.2 onnxruntime-gpu==1.22.0 onnxruntime-tools
Torch 2.4.1 + CUDA 12.4
transformers 4.53.x
optimum 1.27.0
onnxruntime-gpu 1.22.0
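Everything below assumes a model object is already loaded on the GPU. A minimal loading sketch (the checkpoint name is just an example; any Mistral-7B variant you have access to works):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.3"  # example checkpoint, swap in your own

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,    # fp16 keeps VRAM and the export manageable
    attn_implementation="eager",  # eager makes HF build an explicit causal mask for the patched attention
).to("cuda").eval()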
🔧 Patch the Attention Layer
Replace fused SDPA with plain matmul + softmax, handle grouped K/V heads (num_key_value_groups), and keep rotary position embeddings applied to q/k so the exported math matches the original model.
import torch
import torch.nn.functional as F
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb

def export_safe_forward(self, hidden_states, attention_mask=None, **kwargs):
    bsz, q_len, _ = hidden_states.size()
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    v = self.v_proj(hidden_states)
    num_heads = self.config.hidden_size // self.head_dim
    num_kv_heads = num_heads // self.num_key_value_groups
    q = q.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
    k = k.view(bsz, q_len, num_kv_heads, self.head_dim).transpose(1, 2)
    v = v.view(bsz, q_len, num_kv_heads, self.head_dim).transpose(1, 2)
    # Keep RoPE: the decoder layer passes (cos, sin) in as position_embeddings
    position_embeddings = kwargs.get("position_embeddings")
    if position_embeddings is not None:
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
    # Expand grouped K/V heads so plain matmul attention sees one K/V per query head
    if num_kv_heads != num_heads:
        repeat_factor = num_heads // num_kv_heads
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)
    attn_weights = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, v)
    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)
    # Return the (output, attention weights) tuple MistralDecoderLayer expects
    return attn_output, None

MistralAttention.forward = export_safe_forward
print("✅ Patched MistralAttention for export")
🧰 Export to ONNX
Critical fix: strip HF’s DynamicCache by overriding the model’s forward to return only logits.
import types
from torch.onnx import OperatorExportTypes

# ensure cache is off
model.config.use_cache = False
model.eval()

# override top-level forward so only logits are returned (no DynamicCache)
def export_forward(self, input_ids, **kwargs):
    outputs = self.__class__.forward(self, input_ids, use_cache=False, **kwargs)
    return outputs.logits

model.forward = types.MethodType(export_forward, model)
onnx_out = "mistral_custom_clean_logits.onnx"
dummy_input_ids = torch.randint(0, model.config.vocab_size, (1, 128), dtype=torch.long).to("cuda")

print("Exporting with patched attention...")
torch.onnx.export(
    model,
    (dummy_input_ids,),
    onnx_out,
    input_names=["input_ids"],
    output_names=["logits"],
    opset_version=15,
    do_constant_folding=True,
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
)
print("✅ Exported to", onnx_out)
🧹 Clean the Graph (Remove ior)
If your export fails with:
UnsupportedOperatorError: Exporting the operator 'aten::__ior_'
don’t panic. Patch the .onnx file:
import onnx

path_in = "mistral_custom_clean_logits.onnx"
path_out = "mistral_custom_clean_logits_fixed.onnx"

# Load the graph only; the multi-GB weights stay as external data on disk
model_onnx = onnx.load(path_in, load_external_data=False)

# Replace "__ior__" nodes with "Or" if any remain
fixed = False
for node in model_onnx.graph.node:
    if node.op_type == "__ior__":
        print("⚡ Fixing node:", node.name)
        node.op_type = "Or"
        fixed = True

onnx.save(model_onnx, path_out if fixed else path_in)
print("✅ Saved:", path_out if fixed else path_in)
Use mistral_custom_clean_logits_fixed.onnx if it was created; otherwise keep using the original.
🔬 Validate + Test Inference
Don't rely on model.config.vocab_size here (the PyTorch model may no longer be loaded at this point). Either read the vocab size from the ONNX graph or fall back to Mistral's vocabulary size (32768 for v0.3).
import os
import numpy as np
import onnx
import onnxruntime as ort

# choose the fixed file if it exists
onnx_path = "mistral_custom_clean_logits_fixed.onnx"
if not os.path.exists(onnx_path):
    onnx_path = "mistral_custom_clean_logits.onnx"

# Check by path: for models over 2 GB the checker needs the file, not an in-memory proto
onnx.checker.check_model(onnx_path)
print("✅ ONNX validated")

# Prefer reading the vocab size from the model's output shape; fall back to 32768
onnx_model = onnx.load(onnx_path, load_external_data=False)
out_shape = onnx_model.graph.output[0].type.tensor_type.shape.dim
if len(out_shape) >= 3 and out_shape[2].HasField("dim_value"):
    vocab_size = out_shape[2].dim_value
else:
    vocab_size = 32768
print("Vocab size:", vocab_size)

session = ort.InferenceSession(
    onnx_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

dummy = np.random.randint(0, vocab_size, size=(1, 16), dtype=np.int64)
logits = session.run(None, {"input_ids": dummy})[0]
print("Logits shape:", logits.shape)  # (1, 16, vocab_size)
print("Sample:", logits[0, -1, :10])
🎮 Unity Sentis 6.2
Unity 6.2's Sentis runtime supports opset ≤ 15, which this export respects (the snippet after this list double-checks the opset). With it you can:
Drop mistral_custom_clean_logits_fixed.onnx (or the original export if no __ior__ nodes were found) into Assets/Models/
Load with Sentis.OnnxModel + WorkerFactory
Get full logits out, same as ONNX Runtime
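Before importing into Unity, you can double-check that the exported graph really sticks to opset ≤ 15 (a quick sketch that loads only the graph, not the weights):
import onnx, os

path = "mistral_custom_clean_logits_fixed.onnx"
if not os.path.exists(path):  # fall back to the original export if no fix was needed
    path = "mistral_custom_clean_logits.onnx"

m = onnx.load(path, load_external_data=False)
for imp in m.opset_import:
    print(imp.domain or "ai.onnx", "opset", imp.version)  # expect 15 or lower for the default domain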
✨ Conclusion
They said it couldn’t be done. By patching MistralAttention and stripping DynamicCache, we exported Mistral-7B to ONNX and validated it in ONNX Runtime — ready for Unity Sentis 6.2.
No hacks at inference time, no broken ops. Just pure matmul attention, ONNX-legal.