Please help: how do I use model.onnx?

#5
by jennasu - opened

#ENV: Debian OS, Python 3.7

#code
import onnxruntime as ort
import numpy as np
import onnx
# Deserialize the ONNX protobuf from disk and report the IR version it was
# exported with.
# NOTE(review): a `DecodeError: Error parsing message` raised by onnx.load
# often means the file on disk is not a complete ONNX protobuf (for example a
# Git LFS pointer file or an interrupted download) — check the file size.
model_path = "/root/model_fp16.onnx"
model = onnx.load(model_path)
ir_version = model.ir_version
print("Model IR version: {}".format(ir_version))

#ERROR:
model = onnx.load("/root/model_fp16.onnx")
File "/usr/local/lib/python3.7/dist-packages/onnx/__init__.py", line 170, in load_model
model = load_model_from_string(s, format=format)
File "/usr/local/lib/python3.7/dist-packages/onnx/__init__.py", line 212, in load_model_from_string
return _deserialize(s, ModelProto())
File "/usr/local/lib/python3.7/dist-packages/onnx/__init__.py", line 143, in _deserialize
decoded = typing.cast(Optional[int], proto.ParseFromString(s))
google.protobuf.message.DecodeError: Error parsing message

Hugging Face TB Research org
edited Aug 18, 2024

You can load the ONNX model in a similar way as with transformers, using Optimum's ONNX Runtime classes:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch
import os

# Path to your local clone of the model repository, e.g.:
#   git clone https://huggingface.co/HuggingFaceTB/SmolLM-1.7B-Instruct
# NOTE: large files on the Hub are typically stored with Git LFS; if `git lfs`
# is not installed, the clone may contain small pointer files instead of the
# real ONNX weights.
repo_path = "path/to/local/repo"

# Move the ONNX variant you want into the main folder so it is loaded from
# repo_path:
#   cp repo_path/onnx/model.onnx repo_path/model.onnx
# If that variant has an external-data file stored next to it (for example a
# `model.onnx_data` file), copy that file as well.

# Load the tokenizer and the ONNX Runtime model from the local folder.
tokenizer = AutoTokenizer.from_pretrained(repo_path)
model = ORTModelForCausalLM.from_pretrained(repo_path)

# Build the conversation in chat format.
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What is the capital of France ?"},
]

# Render the conversation with the model's chat template and append the
# prompt that opens the assistant's turn.
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Generate the response.
# Why *_new_tokens: min_length/max_length count the prompt tokens as well, so
# with a chat-templated prompt (presumably already longer than 20 tokens)
# min_length=20 had no effect, and max_length=100 left only part of the budget
# for the answer. min_new_tokens/max_new_tokens bound only the generated text.
inputs = tokenizer(input_text, return_tensors="pt")
gen_tokens = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.2,
    top_p=0.9,
    min_new_tokens=20,
    max_new_tokens=100,
)

# Decode and print the output.
output = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
print(output[0])
This comment has been hidden

Sign up or log in to comment