# vietnamese-embedding-onnx / conver_to_onnx.py
# Author: Manh Lai — initial commit (67a897e), 1.99 kB
from pathlib import Path
import onnx
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
# ---------------------------------------------------------------------------
# Convert "dangvantuan/vietnamese-embedding" to ONNX and derive two reduced
# precision variants: an FP16 copy and a dynamically quantized INT8 copy.
# All ONNX artifacts land in ./onnx; tokenizer files land in the repo root.
# ---------------------------------------------------------------------------
model_name = "dangvantuan/vietnamese-embedding"
output_dir = Path("onnx")
output_dir.mkdir(parents=True, exist_ok=True)

# Every artifact path, declared up front.
model_fp32_path = output_dir / "model.onnx"
model_fp16_path = output_dir / "model-fp16.onnx"
model_int8_path = output_dir / "model-int8.onnx"

# --- Step 1: export the FP32 graph --------------------------------------
print("Exporting the FP32 model...")
model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)
model.save_pretrained(output_dir)

# Persist the tokenizer files into the current directory (the repo root),
# not into output_dir, so they sit next to this script.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(Path("."))

# --- Step 2: FP32 -> FP16 -----------------------------------------------
print("Converting to FP16...")
fp32_file = model_fp32_path.as_posix()
model_fp32 = onnx.load(fp32_file)
# keep_io_types=True leaves the graph inputs/outputs in FP32 so callers
# do not have to feed half-precision tensors; only weights become FP16.
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
onnx.save(model_fp16, model_fp16_path.as_posix())

# --- Step 3: FP32 -> INT8 via dynamic quantization ----------------------
print("Converting to INT8 (dynamic quantization)...")
quantize_dynamic(
    fp32_file,
    model_int8_path.as_posix(),
    weight_type=QuantType.QInt8,  # QUInt8 is the alternative if required
)

print("✅ Model conversion complete!")
print(f"FP32 model: {model_fp32_path}")
print(f"FP16 model: {model_fp16_path}")
print(f"INT8 model: {model_int8_path}")