# camie-tagger-v2 / onnx_inference.py
# Uploaded by Camais03 ("Upload 130 files", commit 53766b0, verified); file size 8.72 kB.
def preprocess_image(image_path, image_size=512):
    """
    Process an image for ImageTagger inference with proper ImageNet normalization.

    The image is converted to RGB and resized with LANCZOS so its longer side is
    ``image_size`` pixels, keeping the aspect ratio. It is then centered on an
    ``image_size`` x ``image_size`` canvas filled with an approximation of the
    ImageNet mean colour. Finally it is converted to a tensor and normalized with
    the ImageNet mean/std.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square output image.

    Returns:
        The tensor produced by ``transforms.ToTensor`` followed by
        ``transforms.Normalize`` (channels first, 3 RGB channels).

    Raises:
        ValueError: If ``image_path`` does not exist.
        Exception: If the image cannot be opened or processed. The underlying
            error is chained as ``__cause__``.
    """
    import os

    # Validate the path before importing the heavy imaging libraries so a bad
    # path fails fast.
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    import torchvision.transforms as transforms
    from PIL import Image

    # ImageNet normalization - CRITICAL for this model
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    try:
        with Image.open(image_path) as img:
            # Convert every non-RGB mode, not only RGBA/P (e.g. '1', 'L', 'LA',
            # 'PA', 'CMYK'), so LANCZOS resizing always runs on 3-channel RGB
            # data. Pillow resamples mode '1' with NEAREST regardless of the
            # filter requested.
            if img.mode != 'RGB':
                img = img.convert('RGB')

            # Scale so the longer side equals image_size, keeping the aspect ratio.
            width, height = img.size
            aspect_ratio = width / height
            if aspect_ratio > 1:
                new_width = image_size
                # max(1, ...) prevents a zero-sized side for extreme aspect
                # ratios, which Image.resize rejects.
                new_height = max(1, int(new_width / aspect_ratio))
            else:
                new_height = image_size
                new_width = max(1, int(new_height * aspect_ratio))

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad to a square canvas. The pad colour approximates the ImageNet mean
            # (0.485*255, 0.456*255, 0.406*255), so padding normalizes to ~0.
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Apply transforms (including ImageNet normalization)
            return transform(new_image)
    except Exception as e:
        raise Exception(f"Error processing {image_path}: {str(e)}") from e
def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=50):
    """
    Test ImageTagger ONNX model with proper handling of all outputs.

    This function:
    - loads tag metadata;
    - creates an ONNX Runtime session (CUDA first when this onnxruntime build
      exposes it, CPU as the fallback);
    - preprocesses the image with ``preprocess_image`` and runs inference;
    - prints the tags whose sigmoid probability is at least ``threshold``,
      grouped by category.

    Args:
        model_path: Path to ONNX model file.
        metadata_path: Path to metadata JSON file. It must contain
            ``dataset_info.tag_mapping.idx_to_tag``,
            ``dataset_info.tag_mapping.tag_to_category``,
            ``dataset_info.total_tags`` and ``model_info.img_size``.
        image_path: Path to test image.
        threshold: Minimum probability for a tag to be reported.
        top_k: Maximum number of predictions kept *per category*
            (``None`` keeps all of them).

    Returns:
        A dict mapping category name to a list of ``(tag_name, probability)``
        tuples sorted by descending probability. Returns ``{}`` if no tag
        reaches ``threshold``.

    Raises:
        ValueError: If the metadata cannot be loaded or lacks a required key,
            or if image preprocessing fails.
        RuntimeError: If the ONNX session cannot be created or inference fails.
    """
    import json
    import time
    from collections import defaultdict

    import numpy as np
    import onnxruntime as ort

    print(f"Loading ImageTagger ONNX model from {model_path}")

    # Load metadata. JSON is UTF-8 by spec and tag names may be non-ASCII, so
    # don't depend on the platform's default encoding.
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}") from e

    # Extract tag mappings and the input size up front, so a missing key is
    # reported as a metadata problem instead of a later "preprocessing failed".
    try:
        dataset_info = metadata['dataset_info']
        tag_mapping = dataset_info['tag_mapping']
        idx_to_tag = tag_mapping['idx_to_tag']
        tag_to_category = tag_mapping['tag_to_category']
        total_tags = dataset_info['total_tags']
        img_size = metadata['model_info']['img_size']
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}") from e
    print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")

    # Request CUDA only when this onnxruntime build exposes it; CPU is always
    # kept as the fallback.
    providers = ['CPUExecutionProvider']
    if 'CUDAExecutionProvider' in ort.get_available_providers():
        providers.insert(0, 'CUDAExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}") from e

    active_provider = session.get_providers()[0]
    print(f"Using provider: {active_provider}")

    # Print model info
    inputs = session.get_inputs()
    output_specs = session.get_outputs()
    print(f"Model inputs: {len(inputs)}")
    print(f"Model outputs: {len(output_specs)}")
    for i, output in enumerate(output_specs):
        print(f" Output {i}: {output.name} {output.shape}")

    # Preprocess image
    print(f"Processing image: {image_path}")
    try:
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor.unsqueeze(0).numpy()  # Add batch dimension
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}") from e
    print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")

    # Run inference; perf_counter is the monotonic clock intended for timing.
    input_name = inputs[0].name
    print("Running inference...")
    start_time = time.perf_counter()
    try:
        outputs = session.run(None, {input_name: img_numpy})
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e
    inference_time = time.perf_counter() - start_time
    print(f"Inference completed in {inference_time:.4f} seconds")

    # outputs[0] = initial_predictions, outputs[1] = refined_predictions,
    # outputs[2] = selected_candidates (optional).
    # Use the refined predictions when present.
    if len(outputs) >= 2:
        main_logits = outputs[1]
        print(f"Using refined predictions (shape: {main_logits.shape})")
    else:
        # Fallback to single output
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Numerically stable sigmoid: exp() only sees non-positive arguments, so
    # very negative logits no longer overflow (and emit RuntimeWarnings).
    logits = np.asarray(main_logits, dtype=np.float64)
    exp_neg_abs = np.exp(-np.abs(logits))
    main_probs = np.where(
        logits >= 0,
        1.0 / (1.0 + exp_neg_abs),
        exp_neg_abs / (1.0 + exp_neg_abs),
    )

    # Indices (for the first batch item) whose probability reaches the threshold.
    indices = np.where(main_probs[0] >= threshold)[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        # Show top 5 regardless of threshold
        print("Top 5 predictions:")
        for idx in np.argsort(main_probs[0])[-5:][::-1]:
            tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f" {tag_name}: {prob:.3f}")
        return {}

    # Group by category; idx_to_tag keys are strings because they come from JSON.
    tags_by_category = defaultdict(list)
    for idx in indices:
        tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        tags_by_category[category].append((tag_name, float(main_probs[0, idx])))

    # Sort by probability within each category, then keep at most top_k per category.
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category],
            key=lambda x: x[1],
            reverse=True,
        )[:top_k]

    # Print results
    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Standard categories first in a fixed order, then any others alphabetically.
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    display_order = [c for c in category_order if c in tags_by_category]
    display_order += sorted(c for c in tags_by_category if c not in category_order)
    for category in display_order:
        tags = tags_by_category[category]
        print(f"\n{category.upper()} ({len(tags)}):")
        for tag, prob in tags:
            print(f" {tag}: {prob:.3f}")

    # Performance stats
    print("\nPerformance:")
    print(f" Inference time: {inference_time:.4f}s")
    print(f" Provider: {active_provider}")
    print(f" Max confidence: {main_probs.max():.3f}")
    # total_predictions can be 0 here when top_k == 0.
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f" Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)