def preprocess_image(image_path, image_size=512):
    """Process an image for ImageTagger inference with proper ImageNet normalization."""
    import torchvision.transforms as transforms
    from PIL import Image
    import os
    
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at path: {image_path}")
    
    # ImageNet normalization - must match the preprocessing used at training time
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225]
        )
    ])
    
    try:
        with Image.open(image_path) as img:
            # Convert RGBA or Palette images to RGB
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')
            
            # Get original dimensions
            width, height = img.size
            aspect_ratio = width / height
            
            # Calculate new dimensions to maintain aspect ratio
            if aspect_ratio > 1:
                new_width = image_size
                new_height = int(new_width / aspect_ratio)
            else:
                new_height = image_size
                new_width = int(new_height * aspect_ratio)
            
            # Resize with LANCZOS filter
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            
            # Create new image with padding (use ImageNet mean for padding)
            # Using RGB values close to ImageNet mean: (0.485*255, 0.456*255, 0.406*255)
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))
            
            # Apply transforms (including ImageNet normalization)
            img_tensor = transform(new_image)
            return img_tensor
            
    except Exception as e:
        # Re-raise with context while preserving the original traceback
        raise RuntimeError(f"Error processing {image_path}: {e}") from e
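
# Illustrative usage of preprocess_image (the path below is a placeholder,
# not a file shipped with this script):
#
#     tensor = preprocess_image("example.jpg", image_size=512)
#     print(tensor.shape)  # torch.Size([3, 512, 512]), float32, ImageNet-normalized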

def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=50):
    """
    Test an ImageTagger ONNX model with proper handling of all of its outputs.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions
        top_k: Maximum number of predictions to show per category
    """
    import onnxruntime as ort
    import numpy as np
    import json
    import time
    from collections import defaultdict
    
    print(f"Loading ImageTagger ONNX model from {model_path}")
    
    # Load metadata with proper error handling
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}")
    
    # Extract tag mappings from nested structure
    try:
        dataset_info = metadata['dataset_info']
        tag_mapping = dataset_info['tag_mapping']
        idx_to_tag = tag_mapping['idx_to_tag']
        tag_to_category = tag_mapping['tag_to_category']
        total_tags = dataset_info['total_tags']
        
        print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")
        
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}")
    
    # Initialize ONNX session, preferring CUDA only when it is actually available
    providers = []
    if 'CUDAExecutionProvider' in ort.get_available_providers():
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')
    
    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")
        
        # Print model info
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(outputs)}")
        for i, output in enumerate(outputs):
            print(f"  Output {i}: {output.name} {output.shape}")
            
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}")
    
    # Preprocess image
    print(f"Processing image: {image_path}")
    try:
        # Use the image size recorded in metadata, falling back to the 512px default
        img_size = metadata.get('model_info', {}).get('img_size', 512)
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor.unsqueeze(0).numpy()  # Add batch dimension -> (1, 3, H, W)
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
        
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}")
    
    # Run inference
    input_name = session.get_inputs()[0].name
    print("Running inference...")
    
    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")
        
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}")
    
    # Handle outputs:
    # outputs[0] = initial_predictions, outputs[1] = refined_predictions, outputs[2] = selected_candidates
    if len(outputs) >= 2:
        initial_logits = outputs[0]  # kept for reference; not used below
        refined_logits = outputs[1]
        selected_candidates = outputs[2] if len(outputs) > 2 else None  # optional third output, unused here
        
        # Use refined predictions as main output
        main_logits = refined_logits
        print(f"Using refined predictions (shape: {refined_logits.shape})")
        
    else:
        # Fallback to single output
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")
    
    # Apply sigmoid to get probabilities
    main_probs = 1.0 / (1.0 + np.exp(-main_logits))
    
    # Apply threshold and get predictions
    predictions_mask = (main_probs >= threshold)
    indices = np.where(predictions_mask[0])[0]
    
    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        # Show top 5 regardless of threshold
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            idx_str = str(idx)
            tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f"  {tag_name}: {prob:.3f}")
        return {}
    
    # Group by category  
    tags_by_category = defaultdict(list)
    
    for idx in indices:
        idx_str = str(idx)
        tag_name = idx_to_tag.get(idx_str, f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        prob = float(main_probs[0, idx])
        
        tags_by_category[category].append((tag_name, prob))
    
    # Sort by probability within each category
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category], 
            key=lambda x: x[1], 
            reverse=True
        )[:top_k]  # Limit per category
    
    # Print results
    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")
    
    # Category order for consistent display
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    
    for category in category_order:
        if category in tags_by_category:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")
    
    # Show any other categories not in standard order
    for category in sorted(tags_by_category.keys()):
        if category not in category_order:
            tags = tags_by_category[category]
            print(f"\n{category.upper()} ({len(tags)}):")
            for tag, prob in tags:
                print(f"  {tag}: {prob:.3f}")
    
    # Performance stats
    print(f"\nPerformance:")
    print(f"  Inference time: {inference_time:.4f}s")
    print(f"  Provider: {active_provider}")
    print(f"  Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f"  Average confidence: {avg_conf:.3f}")
    
    return dict(tags_by_category)
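

if __name__ == "__main__":
    # Minimal command-line entry point (a usage sketch; the default file names
    # below are placeholders, not files provided by this script).
    import argparse

    parser = argparse.ArgumentParser(
        description="Run ImageTagger ONNX inference on a single image"
    )
    parser.add_argument("--model", default="model.onnx",
                        help="Path to the exported ONNX model (placeholder default)")
    parser.add_argument("--metadata", default="metadata.json",
                        help="Path to the metadata JSON file (placeholder default)")
    parser.add_argument("--image", required=True,
                        help="Path to the test image")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Confidence threshold for predictions")
    parser.add_argument("--top-k", type=int, default=50,
                        help="Maximum predictions shown per category")
    args = parser.parse_args()

    test_onnx_imagetagger(
        model_path=args.model,
        metadata_path=args.metadata,
        image_path=args.image,
        threshold=args.threshold,
        top_k=args.top_k,
    )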