import json

import numpy as np
import onnxruntime as ort
from PIL import Image
def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """
    Load an image from img_path, convert it to RGB, and resize/pad it
    to (target_size, target_size). Pixel values are scaled to [0, 1],
    and a (1, 3, target_size, target_size) float32 array is returned.
    """
    img = Image.open(img_path).convert("RGB")
    if keep_aspect:
        # Preserve the aspect ratio and pad with black.
        w, h = img.size
        aspect = w / h
        if aspect > 1:
            new_w = target_size
            new_h = int(new_w / aspect)
        else:
            new_h = target_size
            new_w = int(new_h * aspect)
        # Resize with Lanczos resampling (Image.Resampling requires Pillow >= 9.1).
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Paste the resized image onto a centered square black canvas.
        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        paste_x = (target_size - new_w) // 2
        paste_y = (target_size - new_h) // 2
        background.paste(img, (paste_x, paste_y))
        img = background
    else:
        # Direct resize to (target_size, target_size), ignoring the aspect ratio.
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
    # Convert to a numpy array scaled to [0, 1].
    arr = np.array(img).astype("float32") / 255.0
    # Transpose from HWC to CHW layout.
    arr = np.transpose(arr, (2, 0, 1))
    # Add a batch dimension: (1, 3, target_size, target_size).
    arr = np.expand_dims(arr, axis=0)
    return arr
def onnx_inference(img_paths,
                   onnx_path="camie_refined_no_flash.onnx",
                   threshold=0.325,
                   metadata_file="metadata.json"):
    """
    Load the ONNX model, run inference on a list of image paths, and
    apply a probability threshold to produce the final predictions.

    Args:
        img_paths: List of paths to images.
        onnx_path: Path to the exported ONNX model file.
        threshold: Probability threshold for deciding whether a tag is predicted.
        metadata_file: Path to the metadata.json file that contains idx_to_tag etc.

    Returns:
        A list with one dict per input image, each containing:
            {
                "initial_logits": np.ndarray of shape (N_tags,),
                "refined_logits": np.ndarray of shape (N_tags,),
                "predicted_indices": indices of tags whose probability exceeded threshold,
                "predicted_tags": the corresponding tag names,
            }
    """
    # 1) Initialize the ONNX Runtime session.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # Optional: for GPU usage, check whether "CUDAExecutionProvider" is available:
    # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])

    # 2) Load the metadata, including the index-to-tag-name mapping.
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. {"0": "brown_hair", "1": "blue_eyes", ...}

    # 3) Preprocess each image and collect the tensors into a batch.
    batch_tensors = []
    for img_path in img_paths:
        x = preprocess_image(img_path, target_size=512, keep_aspect=True)
        batch_tensors.append(x)
    # Concatenate along the batch dimension => shape (batch_size, 3, 512, 512).
    batch_input = np.concatenate(batch_tensors, axis=0)

    # 4) Run inference.
    input_name = session.get_inputs()[0].name  # typically "image"
    outputs = session.run(None, {input_name: batch_input})
    # Typically the model returns [initial_tags, refined_tags], shapes => (batch_size, 70527).
    initial_preds, refined_preds = outputs
    # 5) For each image in the batch, convert logits to thresholded predictions.
    batch_results = []
    for i in range(initial_preds.shape[0]):
        # Extract one sample's logits.
        init_logit = initial_preds[i, :]  # shape (N_tags,)
        ref_logit = refined_preds[i, :]   # shape (N_tags,)
        # Convert the refined logits to probabilities with a sigmoid.
        ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
        # Keep the indices whose probability clears the threshold.
        pred_indices = np.where(ref_prob >= threshold)[0]
        # Build the result for this image, mapping each index to its tag name.
        result_dict = {
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": pred_indices,
            "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices],
        }
        batch_results.append(result_dict)
    return batch_results
if __name__ == "__main__":
    # Example usage
    images = ["image1.jpg", "image2.jpg", "image3.jpg"]
    results = onnx_inference(images,
                             onnx_path="camie_refined_no_flash.onnx",
                             threshold=0.325,
                             metadata_file="metadata.json")
    for i, res in enumerate(results):
        print(f"Image: {images[i]}")
        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
        print(f"  Some predicted tags: {res['predicted_tags'][:10]} (up to 10 shown)")
        print()
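The session above pins the CPU execution provider. If the onnxruntime-gpu build is installed, you can let ONNX Runtime prefer CUDA and fall back to the CPU automatically. A minimal sketch, assuming the same model file as above (ort.get_available_providers() is part of the onnxruntime API):

import onnxruntime as ort

# Query the execution providers supported by this onnxruntime build.
available = ort.get_available_providers()
# Prefer CUDA when present; otherwise fall back to the CPU provider.
providers = (["CUDAExecutionProvider", "CPUExecutionProvider"]
             if "CUDAExecutionProvider" in available
             else ["CPUExecutionProvider"])
session = ort.InferenceSession("camie_refined_no_flash.onnx", providers=providers)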