|
|
""" |
|
|
Example inference script for card_segmentation model. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import onnxruntime as ort |
|
|
|
|
|
def preprocess_image(image_path, target_size=(320, 240)):
    """
    Preprocess image for model inference.

    Reads the image with OpenCV, converts it from BGR to RGB, resizes it to
    ``target_size``, scales pixel values to [0, 1], normalizes each channel
    with a fixed mean/std, and reorders the result into a batched
    channels-first tensor.

    Args:
        image_path (str): Path to input image
        target_size (tuple): Target image size (H, W)

    Returns:
        torch.Tensor: Preprocessed float32 tensor of shape (1, 3, H, W)

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    import os

    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Input image not found: {image_path}")

    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an opaque cvtColor error.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not decode image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # cv2.resize expects (width, height), while target_size is (H, W).
    image = cv2.resize(image, (target_size[1], target_size[0]))

    # Keep the statistics in float32: a default (float64) np.array would
    # promote the whole image to float64 and produce a double tensor, which
    # float32 models reject at inference time.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image = (image.astype(np.float32) / 255.0 - mean) / std

    # HWC -> CHW, then add the batch dimension. permute() only changes
    # strides, so make the result contiguous before handing it on.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous()

    return image_tensor
|
|
|
|
|
def postprocess_output(output):
    """
    Turn raw model output into a per-pixel class-index mask.

    A softmax is taken over dimension 1, the index of the largest value along
    that dimension is selected, and the first item of the batch is returned.

    Args:
        output: Model output tensor; dimension 0 is indexed as the batch and
            dimension 1 is reduced by the argmax.

    Returns:
        np.ndarray: Segmentation mask for the first batch item
    """
    class_scores = F.softmax(output, dim=1)
    class_indices = class_scores.argmax(dim=1)

    # Only the first element of the batch is kept.
    first_mask = class_indices[0]
    return first_mask.cpu().numpy()
|
|
|
|
|
def inference_pytorch(model_path, image_path):
    """
    Run inference using PyTorch model.

    The model is loaded with ``torch.jit.load`` onto the CPU and switched to
    eval mode, the image is preprocessed with ``preprocess_image``, and the
    forward pass runs with gradient tracking disabled.

    Args:
        model_path: Path passed to ``torch.jit.load``.
        image_path: Path of the image to segment.

    Returns:
        The mask produced by ``postprocess_output`` for the model output.
    """
    # Load first, then preprocess, so a bad model path fails before any
    # image work is done.
    net = torch.jit.load(model_path, map_location='cpu')
    net.eval()

    batch = preprocess_image(image_path)

    with torch.no_grad():
        logits = net(batch)

    return postprocess_output(logits)
|
|
|
|
|
def inference_onnx(model_path, image_path):
    """
    Run inference using ONNX model.

    Creates an ONNX Runtime session for ``model_path``, preprocesses the image
    with ``preprocess_image``, feeds it to the model's first input, and
    postprocesses the first output with ``postprocess_output``.

    Args:
        model_path: Path to the ONNX model file.
        image_path: Path of the image to segment.

    Returns:
        The mask produced by ``postprocess_output`` for the model output.
    """
    # Name the execution provider explicitly: since ONNX Runtime 1.9, builds
    # that ship additional providers (e.g. CUDA/TensorRT) refuse to create a
    # session without one. CPU matches the map_location='cpu' choice made in
    # the PyTorch path of this script.
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    input_tensor = preprocess_image(image_path)
    input_array = input_tensor.numpy()

    # Only the first declared input and the first returned output are used.
    input_name = session.get_inputs()[0].name
    output = session.run(None, {input_name: input_array})[0]

    # Wrap the NumPy result in a tensor so the shared torch-based
    # postprocessing can be reused.
    output_tensor = torch.from_numpy(output)
    mask = postprocess_output(output_tensor)

    return mask
|
|
|
|
|
def save_mask(mask, output_path):
    """
    Save segmentation mask as image.

    Mask values are multiplied by 255 and cast to uint8 before writing, so a
    0/1 mask becomes a black/white image.

    Args:
        mask: NumPy array holding the mask values.
        output_path: Destination file path handed to ``cv2.imwrite``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    mask_image = (mask * 255).astype(np.uint8)

    # cv2.imwrite returns False on failure instead of raising; without this
    # check a failed write would go unnoticed by the caller.
    if not cv2.imwrite(output_path, mask_image):
        raise OSError(f"Failed to write mask image to: {output_path}")
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Maps each accepted --format value to the function that runs it.
    backends = {
        'pytorch': inference_pytorch,
        'onnx': inference_onnx,
    }

    cli = argparse.ArgumentParser(description='Run inference on card segmentation model')
    cli.add_argument('--model', type=str, required=True, help='Path to model file')
    cli.add_argument('--image', type=str, required=True, help='Path to input image')
    cli.add_argument('--output', type=str, default='output_mask.png', help='Output mask path')
    cli.add_argument(
        '--format',
        type=str,
        choices=['pytorch', 'onnx'],
        default='onnx',
        help='Model format',
    )
    options = cli.parse_args()

    # argparse restricts --format to the two keys above, so the lookup
    # cannot miss.
    run_inference = backends[options.format]
    result_mask = run_inference(options.model, options.image)

    save_mask(result_mask, options.output)
    print(f"Segmentation mask saved to: {options.output}")
|
|
|