🚀 Exporting Mistral-7B to ONNX for Unity Sentis 6.2

Community Article · Published August 26, 2025

Unity 6.2 just dropped with upgraded ONNX / Sentis runtime support. But until now, most devs were told:

“You can’t export Mistral to ONNX. PyTorch’s scaled_dot_product_attention and Hugging Face’s masks make it impossible.”

We decided to prove otherwise. This article shows how to patch MistralAttention, bypass fused SDPA, and export a clean .onnx that runs in ONNX Runtime and Unity Sentis 6.2.

⚡ The Problem

Typical export fails with:

UnsupportedOperatorError: Exporting the operator 'aten::__ior_' ...
TypeError: scaled_dot_product_attention ...
ValueError: too many values to unpack ...
RuntimeError: Only tuples, lists and Variables are supported ... (DynamicCache)

Three blockers:

Hugging Face's mask utilities use in-place ops (__ior__), which the ONNX exporter rejects.

PyTorch 2.x's fused scaled_dot_product_attention doesn't trace cleanly to ONNX on this path.

Mistral’s GQA (grouped query attention) plus HF’s DynamicCache breaks ONNX export unless you strip the cache.

🛠️ Environment

We ran on a Colab Pro+ A100, but any CUDA 12.x setup should work.

pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
pip install transformers==4.53.3 optimum==1.27.0 onnx==1.16.2 onnxruntime-gpu==1.22.0 onnxruntime-tools

Torch 2.4.1 + CUDA 12.4

transformers 4.53.x

optimum 1.27.0

onnxruntime-gpu 1.22.0
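The snippets below assume the model and tokenizer are already loaded. A minimal setup sketch (the exact checkpoint and fp16 dtype are our choices; any Mistral-7B variant should work):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.3"  # assumed checkpoint (v0.3 has the 32768-token vocab)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# "eager" attention keeps the attention mask as a plain additive float tensor,
# which is exactly what the patched forward in the next section expects.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="eager",
).to("cuda")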

🔧 Patch the Attention Layer

Replace fused SDPA with plain matmul + softmax, and handle grouped K/V heads (num_key_value_groups).
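As a standalone illustration of that K/V expansion: Mistral-7B runs 32 query heads over 8 shared K/V heads (head_dim 128), so each K/V head is repeated 4 times.

import torch

k = torch.randn(1, 8, 16, 128)          # (batch, kv_heads, seq_len, head_dim)
k_full = k.repeat_interleave(4, dim=1)  # -> (1, 32, 16, 128): one copy per query head
print(k_full.shape)

The full patch: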

import torch
import torch.nn.functional as F
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    apply_rotary_pos_emb,
)

def export_safe_forward(self, hidden_states, attention_mask=None, **kwargs):
    bsz, q_len, _ = hidden_states.size()

    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    v = self.v_proj(hidden_states)

    num_heads = self.config.hidden_size // self.head_dim
    num_kv_heads = num_heads // self.num_key_value_groups

    q = q.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
    k = k.view(bsz, q_len, num_kv_heads, self.head_dim).transpose(1, 2)
    v = v.view(bsz, q_len, num_kv_heads, self.head_dim).transpose(1, 2)

    # Apply rotary position embeddings. transformers 4.53 passes (cos, sin)
    # down from the decoder layer as position_embeddings; skipping RoPE would
    # make the exported logits position-blind and wrong.
    position_embeddings = kwargs.get("position_embeddings")
    if position_embeddings is not None:
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

    # GQA: repeat the shared K/V heads so every query head has a matching pair.
    if num_kv_heads != num_heads:
        repeat_factor = num_heads // num_kv_heads
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)

    # Plain attention: matmul + scale + additive mask + softmax, no fused SDPA.
    attn_weights = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # Upcast the softmax to fp32 for stability, as HF's eager path does.
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)

    attn_output = torch.matmul(attn_weights, v)
    attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    # (attn_output, attn_weights) is the pair MistralDecoderLayer expects;
    # nothing downstream uses the weights, so return None.
    return attn_output, None

MistralAttention.forward = export_safe_forward
print("✅ Patched MistralAttention for export")

🧰 Export to ONNX

Critical fix: strip HF’s DynamicCache by overriding the model’s forward to return only logits.

import types
from torch.onnx import OperatorExportTypes

# ensure cache is off
model.config.use_cache = False
model.eval()

# Override the instance-level forward so only logits are returned (no DynamicCache).
# The class-level forward is untouched, so we can still delegate to it here.
def export_forward(self, input_ids, **kwargs):
    outputs = self.__class__.forward(self, input_ids, use_cache=False, **kwargs)
    return outputs.logits

model.forward = types.MethodType(export_forward, model)

onnx_out = "mistral_custom_clean_logits.onnx"
dummy_input_ids = torch.randint(0, model.config.vocab_size, (1, 128), dtype=torch.long).to("cuda")
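# Sanity check before tracing: the override must hand the exporter a plain
# tensor; a DynamicCache in the outputs is exactly what used to break export.
with torch.no_grad():
    assert isinstance(model(dummy_input_ids), torch.Tensor)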

print("Exporting with patched attention...")
torch.onnx.export(
    model,
    (dummy_input_ids,),
    onnx_out,
    input_names=["input_ids"],
    output_names=["logits"],
    opset_version=15,  # the highest opset Sentis accepts
    do_constant_folding=True,
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"}
    },
    # Pass unsupported aten ops through into the graph instead of aborting the
    # export; the next step cleans up any that remain.
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH
)
print("✅ Exported to", onnx_out)

🧹 Clean the Graph (Remove __ior__)

Without ONNX_FALLTHROUGH, the export aborts with:

UnsupportedOperatorError: Exporting the operator 'aten::__ior_'

With fallthrough enabled, the offending op can instead survive as a node in the exported graph. Don't panic. Patch the .onnx file:

import onnx

path_in  = "mistral_custom_clean_logits.onnx"
path_out = "mistral_custom_clean_logits_fixed.onnx"

# The 7B export keeps its weights in external data files, so load only the
# graph structure instead of pulling multiple GB into the proto.
model_onnx = onnx.load(path_in, load_external_data=False)

# Replace any leftover in-place-or nodes with a plain "Or".
# The exact op_type string can vary between exporter versions, so match loosely.
fixed = False
for node in model_onnx.graph.node:
    if "__ior" in node.op_type:
        print("⚡ Fixing node:", node.name)
        node.op_type = "Or"
        fixed = True

if fixed:
    onnx.save(model_onnx, path_out)  # the external-data references stay intact
    print("✅ Saved:", path_out)
else:
    print("✅ No __ior__ nodes found; keeping", path_in)

Use mistral_custom_clean_logits_fixed.onnx if it was created; otherwise keep using the original.

🔬 Validate + Test Inference

Don't rely on model.config.vocab_size here (the HF model may no longer be in scope). Either parse the vocab size from the ONNX graph or fall back to Mistral-7B v0.3's vocabulary size (32768).

import onnx, onnxruntime as ort, numpy as np, os

# choose the fixed file if it exists
onnx_path = "mistral_custom_clean_logits_fixed.onnx"
if not os.path.exists(onnx_path):
    onnx_path = "mistral_custom_clean_logits.onnx"

# For a multi-GB model, hand the checker the path; a loaded proto over 2 GB fails.
onnx.checker.check_model(onnx_path)
print("✅ ONNX validated")

# Load only the graph structure; skip pulling the external weight files into memory.
onnx_model = onnx.load(onnx_path, load_external_data=False)

# Prefer reading the vocab size from the model's output shape; fall back to 32768 (Mistral-7B v0.3)
out_shape = onnx_model.graph.output[0].type.tensor_type.shape.dim
if len(out_shape) >= 3 and out_shape[2].HasField("dim_value"):
    vocab_size = out_shape[2].dim_value
else:
    vocab_size = 32768
print("Vocab size:", vocab_size)

session = ort.InferenceSession(onnx_path,
    providers=["CUDAExecutionProvider","CPUExecutionProvider"])

dummy = np.random.randint(0, vocab_size, size=(1,16), dtype=np.int64)
logits = session.run(None, {"input_ids": dummy})[0]

print("Logits shape:", logits.shape)  # (1, 16, vocab_size)
print("Sample:", logits[0, -1, :10])

🎮 Unity Sentis 6.2

Unity 6.2 supports opset ≤ 15. With this export you can:

Drop mistral_custom_clean_logits_fixed.onnx (or the original export, if no __ior__ nodes were found) into Assets/Models/

Load it with Sentis's ModelLoader and create a worker (WorkerFactory in earlier Sentis releases)

Read the full logits out, same as in ONNX Runtime

✨ Conclusion

They said it couldn’t be done. By patching MistralAttention and stripping DynamicCache, we exported Mistral-7B to ONNX and validated it in ONNX Runtime — ready for Unity Sentis 6.2.

No hacks at inference time, no broken ops. Just pure matmul attention, ONNX-legal.
