def preprocess_image(image_path, image_size=512):
    """
    Process an image for ImageTagger inference with proper ImageNet normalization.
    """
    import torchvision.transforms as transforms
    from PIL import Image
    import os

    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    # Convert to tensor, then apply standard ImageNet mean/std normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    try:
        with Image.open(image_path) as img:
            # Convert palette/alpha images to plain RGB
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            # Resize so the longer side equals image_size, preserving aspect ratio
            width, height = img.size
            aspect_ratio = width / height

            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad to a square canvas (roughly the ImageNet mean color) and
            # center the resized image on it
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Convert to a normalized CHW float tensor
            img_tensor = transform(new_image)
            return img_tensor

    except Exception as e:
        raise RuntimeError(f"Error processing {image_path}: {e}") from e


def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=50):
    """
    Test an ImageTagger ONNX model with proper handling of all outputs.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions
        top_k: Maximum number of predictions to show per category
    """
    import onnxruntime as ort
    import numpy as np
    import json
    import time
    from collections import defaultdict

						|  | print(f"Loading ImageTagger ONNX model from {model_path}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | with open(metadata_path, 'r') as f: | 
					
						
						|  | metadata = json.load(f) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | raise ValueError(f"Failed to load metadata: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | dataset_info = metadata['dataset_info'] | 
					
						
						|  | tag_mapping = dataset_info['tag_mapping'] | 
					
						
						|  | idx_to_tag = tag_mapping['idx_to_tag'] | 
					
						
						|  | tag_to_category = tag_mapping['tag_to_category'] | 
					
						
						|  | total_tags = dataset_info['total_tags'] | 
					
						
						|  |  | 
					
						
						|  | print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories") | 
					
						
						|  |  | 
					
						
						|  | except KeyError as e: | 
					
						
						|  | raise ValueError(f"Invalid metadata structure, missing key: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
    # Prefer CUDA when the ONNX Runtime build reports a GPU, fall back to CPU
    providers = []
    if ort.get_device() == 'GPU':
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        # Inspect the model's declared inputs and outputs
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(outputs)}")
        for i, output in enumerate(outputs):
            print(f"  Output {i}: {output.name} {output.shape}")

    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}")

						|  | print(f"Processing image: {image_path}") | 
					
						
						|  | try: | 
					
						
						|  | img_tensor = preprocess_image(image_path, image_size=metadata['model_info']['img_size']) | 
					
						
						|  | img_numpy = img_tensor.unsqueeze(0).numpy() | 
					
						
						|  | print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}") | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | raise ValueError(f"Image preprocessing failed: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | input_name = session.get_inputs()[0].name | 
					
						
						|  | print("Running inference...") | 
					
						
						|  |  | 
					
						
    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")

    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}")

    # The exported model may return several heads: initial logits, refined
    # logits, and optionally the selected candidate indices
    if len(outputs) >= 2:
        initial_logits = outputs[0]
        refined_logits = outputs[1]
        selected_candidates = outputs[2] if len(outputs) > 2 else None

        # Prefer the refined predictions when they are available
        main_logits = refined_logits
        print(f"Using refined predictions (shape: {refined_logits.shape})")

    else:
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Multi-label model: convert logits to per-tag probabilities with a sigmoid
    main_probs = 1.0 / (1.0 + np.exp(-main_logits))

    # Keep only predictions at or above the confidence threshold
    predictions_mask = (main_probs >= threshold)
    indices = np.where(predictions_mask[0])[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")

        # Show the top 5 predictions anyway so the confidence range is visible
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            idx_str = str(idx)
            tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f"  {tag_name}: {prob:.3f}")
        return {}

    # Group predicted tags by category (metadata index keys are strings)
    tags_by_category = defaultdict(list)

    for idx in indices:
        idx_str = str(idx)
        tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        prob = float(main_probs[0, idx])

        tags_by_category[category].append((tag_name, prob))

    # Sort each category by confidence and keep at most top_k tags
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category],
            key=lambda x: x[1],
            reverse=True
        )[:top_k]

    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']

    # Print known categories in a fixed order first
    for category in category_order:
        if category in tags_by_category:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

    # Then print any remaining categories not covered by the fixed order
    for category in sorted(tags_by_category.keys()):
        if category not in category_order:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")

						|  | print(f"\nPerformance:") | 
					
						
						|  | print(f"  Inference time: {inference_time:.4f}s") | 
					
						
						|  | print(f"  Provider: {active_provider}") | 
					
						
						|  | print(f"  Max confidence: {main_probs.max():.3f}") | 
					
						
						|  | if total_predictions > 0: | 
					
						
						|  | avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags]) | 
					
						
						|  | print(f"  Average confidence: {avg_conf:.3f}") | 
					
						
						|  |  | 
					
						
						|  | return dict(tags_by_category) |
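

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original functions above. The default
    # file names ("model.onnx", "metadata.json", "test.jpg") are placeholders;
    # point them at your own exported model, its metadata JSON, and a test image.
    import argparse

    parser = argparse.ArgumentParser(description="Run an ImageTagger ONNX model on a single image")
    parser.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
    parser.add_argument("--metadata", default="metadata.json", help="Path to the metadata JSON file")
    parser.add_argument("--image", default="test.jpg", help="Path to the test image")
    parser.add_argument("--threshold", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--top-k", type=int, default=50, help="Max tags to show per category")
    args = parser.parse_args()

    results = test_onnx_imagetagger(
        model_path=args.model,
        metadata_path=args.metadata,
        image_path=args.image,
        threshold=args.threshold,
        top_k=args.top_k,
    )
    print(f"\nReturned {sum(len(t) for t in results.values())} tags across {len(results)} categories")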