
How to load the model and run inference

Download all of the repository files to a local directory, model_dir.
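If the files are hosted on the Hugging Face Hub, one convenient way to fetch them is snapshot_download; the repo id below is a placeholder, not this model's actual name:

from huggingface_hub import snapshot_download

# Placeholder repo id -- substitute the actual repository name
model_dir = snapshot_download("<user>/<repo>")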

Initialize the ONNX Runtime session

import numpy as np  # used by the decoding loop below
import onnxruntime as ort

session = ort.InferenceSession(model_dir + "/custom_bart.onnx")
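The decoding loop later on this page assumes the exported graph exposes inputs named input_ids, attention_mask, and decoder_input_ids, and a single output named logits. You can verify the names on your copy of the file with:

print([inp.name for inp in session.get_inputs()])   # expect: input_ids, attention_mask, decoder_input_ids
print([out.name for out in session.get_outputs()])  # expect: logits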

Load Tokenizers

from transformers import BartTokenizer, PreTrainedTokenizerFast

input_tokenizer = BartTokenizer.from_pretrained(model_dir + "/input_tokenizer")
output_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir + "/output_tokenizer")

Set up special tokens

bos_token_id = output_tokenizer.bos_token_id
eos_token_id = output_tokenizer.eos_token_id
pad_token_id = output_tokenizer.pad_token_id
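As a quick sanity check, make sure none of these came back as None; a missing special token in the output tokenizer's config would silently break the decoding loop below:

assert None not in (bos_token_id, eos_token_id, pad_token_id), \
    "output tokenizer is missing a special token"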

Inference

# Custom post-processing for the model's raw phoneme output
import re

def remove_intra_word_spaces(text):
    # Strip the <s>/</s> markers (the raw decode below keeps special tokens)
    text = text.replace("<s>", "").replace("</s>", "").strip()

    # Step 1: Split on 2+ spaces (which indicate word boundaries)
    words = re.split(r'\s{2,}', text)

    # Step 2: For each word, remove all single spaces (intra-word spacing)
    cleaned_words = [''.join(word.split()) for word in words]

    # Step 3: Join words back with a single space
    return ' '.join(cleaned_words)
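For a quick check of the spacing convention (two or more spaces separate words, single spaces separate phonemes within a word), a hand-written string behaves like this:

print(remove_intra_word_spaces("<s> h ə  l ˈoʊ </s>"))  # -> hə lˈoʊ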

Custom greedy-decoding inference function built on the ONNX session

def greedy_decode_onnx_full_model(input_text, max_length=512, input_length=128):
    # Encode input
    inputs = input_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=input_length)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Initialize decoder with BOS
    decoder_input_ids = np.array([[bos_token_id]], dtype=np.int64)

    for _ in range(max_length):
        # Run ONNX forward
        ort_inputs = {
            "input_ids": input_ids.astype(np.int64),
            "attention_mask": attention_mask.astype(np.int64),
            "decoder_input_ids": decoder_input_ids.astype(np.int64),
        }

        logits = session.run(["logits"], ort_inputs)[0]
        next_token_logits = logits[:, -1, :]  # (batch, vocab)
        next_token_id = np.argmax(next_token_logits, axis=-1).reshape(1, 1)  # (1, 1)

        # Append new token to decoder input
        decoder_input_ids = np.concatenate([decoder_input_ids, next_token_id], axis=-1)

        if next_token_id[0][0] == eos_token_id:
            break

    # Decode final tokens
    decoded_text = output_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=False)
    return decoded_text
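Note that this is a full-model export with no key/value cache: every iteration re-runs the encoder and re-processes the entire decoder prefix, so decoding cost grows quadratically with output length. That is fine for short transcriptions like the example below, but worth knowing for long inputs.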

Example:

text = "This is a test."
output = greedy_decode_onnx_full_model(text)
cleaned = remove_intra_word_spaces(output)
print("Raw output:", output)
print("Cleaned:", cleaned)

This should print:

Raw output: <s> ð ˌ ɪ s   ɪ z   ɐ   t ˈ ɛ s t . </s>
Cleaned: ðˌɪs ɪz ɐ tˈɛst.