|
def preprocess_image(image_path, image_size=512):
    """
    Process an image for ImageTagger inference with proper ImageNet normalization.

    The image is converted to RGB and resized so that its longer side equals
    ``image_size`` (aspect ratio preserved, LANCZOS resampling). It is then
    centred on an ``image_size`` x ``image_size`` canvas filled with a neutral
    pad colour, converted to a tensor, and normalized with the ImageNet mean/std.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square output image.

    Returns:
        The tensor produced by ``transforms.ToTensor()`` followed by
        ``transforms.Normalize`` on the padded RGB image (channels first,
        i.e. (3, image_size, image_size)).

    Raises:
        ValueError: If ``image_path`` is not an existing regular file.
        Exception: If the image cannot be opened or processed; the underlying
            error is chained as ``__cause__``.
    """
    import os

    # Validate the path before importing the heavy imaging libraries so a bad
    # path fails fast with a clear error.
    if not os.path.isfile(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    import torchvision.transforms as transforms
    from PIL import Image

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Roughly the ImageNet mean on a 0-255 scale (0.485*255 ~ 124,
    # 0.456*255 ~ 116, 0.406*255 ~ 104), so padded pixels normalize to ~0.
    pad_color = (124, 116, 104)

    try:
        with Image.open(image_path) as img:
            # Bring every mode (RGBA, P, L, LA, CMYK, ...) to 3-channel RGB up
            # front so all inputs go through the same resize/paste pipeline.
            if img.mode != 'RGB':
                img = img.convert('RGB')

            width, height = img.size
            aspect_ratio = width / height

            # Scale the longer side to image_size. Clamp the shorter side to
            # at least 1 px: for extreme aspect ratios int() would otherwise
            # yield 0, which PIL rejects as a resize target.
            if aspect_ratio > 1:
                new_width = image_size
                new_height = max(1, int(new_width / aspect_ratio))
            else:
                new_height = image_size
                new_width = max(1, int(new_height * aspect_ratio))

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Letterbox: centre the resized image on a square padded canvas.
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

        return transform(new_image)
    except Exception as e:
        raise Exception(f"Error processing {image_path}: {str(e)}") from e
|
|
|
|
def _load_imagetagger_metadata(metadata_path):
    """
    Load and validate the ImageTagger metadata JSON file.

    Args:
        metadata_path: Path to the metadata JSON file.

    Returns:
        Tuple ``(idx_to_tag, tag_to_category, total_tags, img_size)`` read from
        ``metadata['dataset_info']['tag_mapping']``, ``metadata['dataset_info']``
        and ``metadata['model_info']``.

    Raises:
        ValueError: If the file cannot be read or parsed, or a required key is
            missing.
    """
    import json

    try:
        # JSON is UTF-8 by specification; don't depend on the platform's
        # default text encoding.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except (OSError, TypeError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        raise ValueError(f"Failed to load metadata: {e}") from e

    try:
        dataset_info = metadata['dataset_info']
        tag_mapping = dataset_info['tag_mapping']
        idx_to_tag = tag_mapping['idx_to_tag']
        tag_to_category = tag_mapping['tag_to_category']
        total_tags = dataset_info['total_tags']
        # Validated here (not during preprocessing) so a missing key is
        # reported as a metadata problem rather than an image problem.
        img_size = metadata['model_info']['img_size']
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}") from e

    return idx_to_tag, tag_to_category, total_tags, img_size


def _create_imagetagger_session(model_path):
    """
    Create an ONNX Runtime session for the model and print its I/O summary.

    CUDA is requested first when ``ort.get_device()`` reports ``'GPU'``; the
    CPU provider is always appended as a fallback.

    Returns:
        Tuple ``(session, active_provider)`` where ``active_provider`` is the
        first entry of ``session.get_providers()``.

    Raises:
        RuntimeError: If the session cannot be created or inspected.
    """
    import onnxruntime as ort

    providers = ['CUDAExecutionProvider'] if ort.get_device() == 'GPU' else []
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        model_inputs = session.get_inputs()
        model_outputs = session.get_outputs()
        print(f"Model inputs: {len(model_inputs)}")
        print(f"Model outputs: {len(model_outputs)}")
        for i, output in enumerate(model_outputs):
            print(f" Output {i}: {output.name} {output.shape}")
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}") from e

    return session, active_provider


def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=50):
    """
    Test ImageTagger ONNX model with proper handling of all outputs.

    Loads the metadata and model, preprocesses ``image_path`` with
    ``preprocess_image``, runs inference, and prints the predicted tags grouped
    by category together with basic performance figures.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions; tags whose sigmoid
            probability is >= threshold are kept
        top_k: Maximum number of predictions kept *per category*

    Returns:
        Dict mapping category name to a list of ``(tag_name, probability)``
        tuples sorted by descending probability, or ``{}`` when no tag reaches
        ``threshold``.

    Raises:
        ValueError: If the metadata cannot be loaded/validated or image
            preprocessing fails.
        RuntimeError: If session creation or inference fails.
    """
    import time
    from collections import defaultdict

    import numpy as np

    print(f"Loading ImageTagger ONNX model from {model_path}")

    idx_to_tag, tag_to_category, total_tags, img_size = _load_imagetagger_metadata(metadata_path)
    print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")

    session, active_provider = _create_imagetagger_session(model_path)

    print(f"Processing image: {image_path}")
    try:
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor.unsqueeze(0).numpy()  # add a batch dimension
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}") from e

    input_name = session.get_inputs()[0].name
    print("Running inference...")

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        results = session.run(None, {input_name: img_numpy})
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e
    inference_time = time.perf_counter() - start_time
    print(f"Inference completed in {inference_time:.4f} seconds")

    if len(results) >= 2:
        # Multi-output model: results[0] is treated as the initial logits and
        # results[1] as the refined logits, which are preferred. Any further
        # outputs are not used here.
        main_logits = results[1]
        print(f"Using refined predictions (shape: {main_logits.shape})")
    else:
        main_logits = results[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Logistic sigmoid. Very negative logits overflow np.exp to inf, which
    # still yields the correct probability 0; silence the spurious warning.
    with np.errstate(over='ignore'):
        main_probs = 1.0 / (1.0 + np.exp(-main_logits))

    # Only the first item of the batch is inspected.
    indices = np.where(main_probs[0] >= threshold)[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
            print(f" {tag_name}: {float(main_probs[0, idx]):.3f}")
        return {}

    # idx_to_tag comes from JSON, so its keys are strings.
    grouped = defaultdict(list)
    for idx in indices:
        tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        grouped[category].append((tag_name, float(main_probs[0, idx])))

    # Keep the top_k most confident tags of each category.
    tags_by_category = {
        category: sorted(tags, key=lambda item: item[1], reverse=True)[:top_k]
        for category, tags in grouped.items()
    }

    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Well-known categories first in a fixed order, then any others alphabetically.
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    display_order = [c for c in category_order if c in tags_by_category]
    display_order += sorted(c for c in tags_by_category if c not in category_order)
    for category in display_order:
        tags = tags_by_category[category]
        print(f"\n{category.upper()} ({len(tags)}):")
        for tag, prob in tags:
            print(f" {tag}: {prob:.3f}")

    print("\nPerformance:")
    print(f" Inference time: {inference_time:.4f}s")
    print(f" Provider: {active_provider}")
    print(f" Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f" Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)