"""
Example inference script for card_segmentation model.
"""
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort


def preprocess_image(image_path, target_size=(320, 240)):
    """
    Preprocess image for model inference.

    Args:
        image_path (str): Path to input image
        target_size (tuple): Target image size (H, W)

    Returns:
        torch.Tensor: Preprocessed image tensor
    """
    # Load image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Resize (cv2.resize expects (width, height))
    image = cv2.resize(image, (target_size[1], target_size[0]))
    # Normalize with ImageNet mean/std; keep everything float32 so the tensor matches the model dtype
    image = image.astype(np.float32) / 255.0
    image = (image - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array(
        [0.229, 0.224, 0.225], dtype=np.float32)
    # Convert to tensor (HWC -> CHW) and add batch dimension
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    return image_tensor


def postprocess_output(output):
    """
    Postprocess model output to get segmentation mask.

    Args:
        output: Model output tensor

    Returns:
        np.ndarray: Binary segmentation mask
    """
    # Apply softmax over the class dimension and take the per-pixel argmax
    probs = F.softmax(output, dim=1)
    pred_mask = torch.argmax(probs, dim=1)
    return pred_mask.cpu().numpy()[0]


def inference_pytorch(model_path, image_path):
    """
    Run inference using a PyTorch (TorchScript) model.
    """
    # Load model
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()
    # Preprocess image
    input_tensor = preprocess_image(image_path)
    # Run inference
    with torch.no_grad():
        output = model(input_tensor)
    # Postprocess
    mask = postprocess_output(output)
    return mask


def inference_onnx(model_path, image_path):
    """
    Run inference using an ONNX model.
    """
    # Load ONNX model
    session = ort.InferenceSession(model_path)
    # Preprocess image
    input_tensor = preprocess_image(image_path)
    input_array = input_tensor.numpy()
    # Run inference
    input_name = session.get_inputs()[0].name
    output = session.run(None, {input_name: input_array})[0]
    # Postprocess (reuse the PyTorch postprocessing on a tensor copy of the output)
    output_tensor = torch.from_numpy(output)
    mask = postprocess_output(output_tensor)
    return mask


def save_mask(mask, output_path):
    """Save segmentation mask as image."""
    # Convert to 0-255 range (binary mask: 0 = background, 255 = card)
    mask_image = (mask * 255).astype(np.uint8)
    cv2.imwrite(output_path, mask_image)
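

# Illustrative extra: overlay the predicted mask on the input image for a quick visual
# check. This is a sketch only; the function name, alpha value, and green highlight are
# arbitrary choices and not part of the model's required pipeline.
def overlay_mask(image_path, mask, output_path, alpha=0.5):
    """Blend the predicted mask over the original image (illustrative helper)."""
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Resize the mask back to the original resolution; nearest-neighbor keeps labels crisp
    mask_resized = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
    # Paint predicted card pixels green (BGR) and alpha-blend with the input image
    color = np.zeros_like(image)
    color[mask_resized > 0] = (0, 255, 0)
    blended = cv2.addWeighted(image, 1.0 - alpha, color, alpha, 0)
    cv2.imwrite(output_path, blended)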


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run inference on card segmentation model')
    parser.add_argument('--model', type=str, required=True, help='Path to model file')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--output', type=str, default='output_mask.png', help='Output mask path')
    parser.add_argument('--format', type=str, choices=['pytorch', 'onnx'], default='onnx',
                        help='Model format')
    args = parser.parse_args()

    # Run inference with the selected backend
    if args.format == 'pytorch':
        mask = inference_pytorch(args.model, args.image)
    else:
        mask = inference_onnx(args.model, args.image)

    # Save result
    save_mask(mask, args.output)
    print(f"Segmentation mask saved to: {args.output}")