Upload ONNX model files
Browse files
README.md
CHANGED
@@ -1,3 +1,126 @@
|
|
1 |
-
---
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: onnx
|
3 |
+
tags:
|
4 |
+
- text-reranking
|
5 |
+
- jina
|
6 |
+
- onnx
|
7 |
+
- fp16
|
8 |
+
pipeline_tag: sentence-similarity
|
9 |
+
---
|
10 |
+
|
11 |
+
# Jina Reranker M0 - ONNX FP16 Version
|
12 |
+
|
13 |
+
This repository contains the [jinaai/jina-reranker-m0](https://huggingface.co/jinaai/jina-reranker-m0) model converted to the ONNX format with FP16 precision.
|
14 |
+
|
15 |
+
## Model Description
|
16 |
+
|
17 |
+
Jina Reranker is designed to rerank search results or document passages based on their relevance to a given query. It takes a query and a list of documents as input and outputs relevance scores.
|
18 |
+
|
19 |
+
This version is specifically exported for use with ONNX Runtime.
|
20 |
+
|
21 |
+
**Original Model Card:** [jinaai/jina-reranker-m0](https://huggingface.co/jinaai/jina-reranker-m0)
|
22 |
+
|
23 |
+
## Technical Details
|
24 |
+
|
25 |
+
* **Format:** ONNX
|
26 |
+
* **Opset:** 14
|
27 |
+
* **Precision:** FP16 (exported using `.half()`)
|
28 |
+
* **External Data:** Uses ONNX external data format due to model size. All files in this repository are required. `huggingface_hub` handles downloading them automatically.
|
29 |
+
* **Export Source:** Exported from the Hugging Face `transformers` library using `torch.onnx.export`.
|
30 |
+
|
31 |
+
## Usage
|
32 |
+
|
33 |
+
You can use this model with `onnxruntime` for inference. You will also need the `transformers` library to load the appropriate processor for input preparation and `huggingface_hub` to download the model files.
|
34 |
+
|
35 |
+
**1. Installation:**
|
36 |
+
|
37 |
+
```bash
|
38 |
+
pip install onnxruntime huggingface_hub transformers torch sentencepiece
|
39 |
+
```
|
40 |
+
|
41 |
+
**2. Inference Script:**
|
42 |
+
|
43 |
+
```python
|
44 |
+
import onnxruntime as ort
|
45 |
+
from huggingface_hub import hf_hub_download
|
46 |
+
from transformers import AutoProcessor
|
47 |
+
import numpy as np
|
48 |
+
import torch # For processor output handling
|
49 |
+
|
50 |
+
# --- Configuration ---
|
51 |
+
# Replace with your repository ID if different
|
52 |
+
repo_id = "jian-mo/jina-reranker-m0-onnx"
|
53 |
+
onnx_filename = "jina-reranker-m0.onnx" # Main ONNX file name
|
54 |
+
# Use the original model ID to load the correct processor
|
55 |
+
original_model_id = "jinaai/jina-reranker-m0"
|
56 |
+
# --- End Configuration ---
|
57 |
+
|
58 |
+
# 1. Download ONNX model files from the Hub
|
59 |
+
# hf_hub_download automatically handles external data files linked via LFS
|
60 |
+
print(f"Downloading ONNX model from {repo_id}...")
|
61 |
+
local_onnx_path = hf_hub_download(
|
62 |
+
repo_id=repo_id,
|
63 |
+
filename=onnx_filename
|
64 |
+
)
|
65 |
+
print(f"ONNX model downloaded to: {local_onnx_path}")
|
66 |
+
|
67 |
+
# 2. Load ONNX Runtime session
|
68 |
+
print("Loading ONNX Inference Session...")
|
69 |
+
# You can choose execution providers, e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
70 |
+
# if you have GPU support and the necessary onnxruntime build.
|
71 |
+
session_options = ort.SessionOptions()
|
72 |
+
# session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
|
73 |
+
providers = ['CPUExecutionProvider'] # Default to CPU
|
74 |
+
session = ort.InferenceSession(local_onnx_path, sess_options=session_options, providers=providers)
|
75 |
+
print(f"ONNX session loaded with provider: {session.get_providers()}")
|
76 |
+
|
77 |
+
# 3. Load the Processor
|
78 |
+
print(f"Loading processor from {original_model_id}...")
|
79 |
+
processor = AutoProcessor.from_pretrained(original_model_id, trust_remote_code=True)
|
80 |
+
print("Processor loaded.")
|
81 |
+
|
82 |
+
# 4. Prepare Input Data
|
83 |
+
query = "What is deep learning?"
|
84 |
+
document = "Deep learning is a subset of machine learning based on artificial neural networks with representation learning."
|
85 |
+
# Example with multiple documents (batch processing)
|
86 |
+
# documents = [
|
87 |
+
# "Deep learning is a subset of machine learning based on artificial neural networks with representation learning.",
|
88 |
+
# "Artificial intelligence refers to the simulation of human intelligence in machines.",
|
89 |
+
# "A transformer is a deep learning model used primarily in the field of natural language processing."
|
90 |
+
# ]
|
91 |
+
# Use processor logic suitable for query + multiple documents if needed
|
92 |
+
|
93 |
+
print("Preparing input data...")
|
94 |
+
# Process query and document together as expected by the reranker model
|
95 |
+
inputs = processor(
|
96 |
+
text=f"{query} {document}",
|
97 |
+
images=None, # Assuming text-only reranking
|
98 |
+
return_tensors="pt", # Get PyTorch tensors first
|
99 |
+
padding=True,
|
100 |
+
truncation=True,
|
101 |
+
max_length=512 # Use a reasonable max_length
|
102 |
+
)
|
103 |
+
|
104 |
+
# Convert to NumPy for ONNX Runtime
|
105 |
+
inputs_np = {
|
106 |
+
"input_ids": inputs["input_ids"].numpy(),
|
107 |
+
"attention_mask": inputs["attention_mask"].numpy()
|
108 |
+
}
|
109 |
+
print("Input data prepared.")
|
110 |
+
# print("Input shapes:", {k: v.shape for k, v in inputs_np.items()})
|
111 |
+
|
112 |
+
# 5. Run Inference
|
113 |
+
print("Running inference...")
|
114 |
+
output_names = [output.name for output in session.get_outputs()]
|
115 |
+
outputs = session.run(output_names, inputs_np)
|
116 |
+
print("Inference complete.")
|
117 |
+
|
118 |
+
# 6. Process Output
|
119 |
+
# The exact interpretation depends on the model's output structure.
|
120 |
+
# For Jina Reranker, the output is typically a logit score.
|
121 |
+
# Higher values usually indicate higher relevance. Check the original model card.
|
122 |
+
print(f"Number of outputs: {len(outputs)}")
|
123 |
+
if len(outputs) > 0:
|
124 |
+
logits = outputs[0]
|
125 |
+
print(f"Output logits shape: {logits.shape}")
|
126 |
+
# Often, the relevance score is associated
|