ONNX
English
File size: 3,431 Bytes
f33bb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Example inference script for card_segmentation model.
"""
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort

def preprocess_image(image_path, target_size=(320, 240),
                     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Load an image from disk and turn it into a normalized model input.

    The image is read with OpenCV (BGR), converted to RGB, resized, scaled to
    [0, 1] and then normalized per channel with ``mean`` and ``std``.

    Args:
        image_path (str): Path to input image.
        target_size (tuple): Target image size as (H, W). ``cv2.resize`` takes
            (W, H), hence the swap below.
        mean (sequence of float): Per-channel RGB mean subtracted after scaling.
        std (sequence of float): Per-channel RGB standard deviation divided
            after centering.

    Returns:
        torch.Tensor: Contiguous float32 tensor of shape (1, C, H, W).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    height, width = target_size
    image = cv2.resize(image, (width, height))

    # Keep the whole pipeline in float32: float64 mean/std arrays would promote
    # the result to float64, which a float32 model / ONNX graph rejects.
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    image = (image.astype(np.float32) / 255.0 - mean) / std

    # HWC -> CHW, add a batch dimension; .contiguous() so that a later
    # .numpy() yields a C-ordered buffer.
    return torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous()

def postprocess_output(output):
    """
    Convert raw model scores into a per-pixel class-index mask.

    Args:
        output: Model output as a ``torch.Tensor`` or ``np.ndarray`` with the
            class scores on axis 1 (presumably laid out as (N, C, H, W) --
            confirm against the exported model).

    Returns:
        np.ndarray: int64 class index per pixel for the first batch item.
            For a two-class model this is a 0/1 mask.
    """
    # as_tensor accepts both tensors (no copy) and NumPy arrays.
    scores = torch.as_tensor(output)
    # Softmax is monotonic along the class axis, so argmax over the raw scores
    # selects the same class without the extra exp/normalize pass.
    pred_mask = torch.argmax(scores, dim=1)

    return pred_mask.cpu().numpy()[0]

def inference_pytorch(model_path, image_path):
    """
    Run inference with a TorchScript model on the CPU.

    Args:
        model_path (str): Path to an archive loadable by ``torch.jit.load``.
        image_path (str): Path to the input image.

    Returns:
        np.ndarray: Mask produced by ``postprocess_output``.
    """
    scripted = torch.jit.load(model_path, map_location='cpu')
    scripted.eval()  # put the module into inference mode

    batch = preprocess_image(image_path)

    # Gradients are never needed for a forward-only pass.
    with torch.no_grad():
        logits = scripted(batch)

    return postprocess_output(logits)

def inference_onnx(model_path, image_path, providers=None):
    """
    Run inference using an ONNX model with ONNX Runtime.

    Args:
        model_path (str): Path to the ONNX model file.
        image_path (str): Path to the input image.
        providers (list, optional): Execution providers passed to
            ``ort.InferenceSession``. Defaults to CPU only, matching the
            CPU-only PyTorch path.

    Returns:
        np.ndarray: Mask produced by ``postprocess_output``.
    """
    if providers is None:
        # ORT builds that ship non-CPU providers require `providers` to be set
        # explicitly (ORT >= 1.9); omitting it raises ValueError there.
        providers = ['CPUExecutionProvider']
    session = ort.InferenceSession(model_path, providers=providers)

    input_tensor = preprocess_image(image_path)
    # Feed ORT a C-contiguous float32 buffer. Assumes the graph input is
    # float32 -- TODO confirm against session.get_inputs()[0].type.
    input_array = np.ascontiguousarray(input_tensor.numpy(), dtype=np.float32)

    input_name = session.get_inputs()[0].name
    output = session.run(None, {input_name: input_array})[0]

    output_tensor = torch.from_numpy(output)
    return postprocess_output(output_tensor)

def save_mask(mask, output_path):
    """
    Save a segmentation mask as an 8-bit single-channel image.

    Pixel values are ``mask * 255`` clipped to [0, 255], so label 0 maps to 0
    and label 1 maps to 255.

    Args:
        mask (np.ndarray): Per-pixel labels (e.g. from ``postprocess_output``).
        output_path (str): Destination file; OpenCV picks the format from the
            extension.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    # Scale in float64 and clip so labels > 1 (or uint8 inputs) saturate at
    # 255 instead of wrapping around in the uint8 cast.
    scaled = np.asarray(mask, dtype=np.float64) * 255
    mask_image = np.clip(scaled, 0, 255).astype(np.uint8)
    if not cv2.imwrite(output_path, mask_image):
        # cv2.imwrite reports failure through its return value, not an exception.
        raise OSError(f"Failed to write mask image: {output_path}")

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description='Run inference on card segmentation model')
    cli.add_argument('--model', type=str, required=True, help='Path to model file')
    cli.add_argument('--image', type=str, required=True, help='Path to input image')
    cli.add_argument('--output', type=str, default='output_mask.png', help='Output mask path')
    cli.add_argument('--format', type=str, choices=['pytorch', 'onnx'], default='onnx',
                     help='Model format')
    opts = cli.parse_args()

    # Pick the backend: anything other than 'pytorch' goes to ONNX Runtime.
    runner = inference_pytorch if opts.format == 'pytorch' else inference_onnx
    mask = runner(opts.model, opts.image)

    save_mask(mask, opts.output)
    print(f"Segmentation mask saved to: {opts.output}")