""" Example inference script for card_segmentation model. """ import torch import torch.nn.functional as F import cv2 import numpy as np from PIL import Image import onnxruntime as ort def preprocess_image(image_path, target_size=(320, 240)): """ Preprocess image for model inference. Args: image_path (str): Path to input image target_size (tuple): Target image size (H, W) Returns: torch.Tensor: Preprocessed image tensor """ # Load image image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Resize image = cv2.resize(image, (target_size[1], target_size[0])) # Normalize image = image.astype(np.float32) / 255.0 image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225]) # Convert to tensor and add batch dimension image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) return image_tensor def postprocess_output(output): """ Postprocess model output to get segmentation mask. Args: output: Model output tensor Returns: np.ndarray: Binary segmentation mask """ # Apply softmax and get predictions probs = F.softmax(output, dim=1) pred_mask = torch.argmax(probs, dim=1) return pred_mask.cpu().numpy()[0] def inference_pytorch(model_path, image_path): """ Run inference using PyTorch model. """ # Load model model = torch.jit.load(model_path, map_location='cpu') model.eval() # Preprocess image input_tensor = preprocess_image(image_path) # Run inference with torch.no_grad(): output = model(input_tensor) # Postprocess mask = postprocess_output(output) return mask def inference_onnx(model_path, image_path): """ Run inference using ONNX model. """ # Load ONNX model session = ort.InferenceSession(model_path) # Preprocess image input_tensor = preprocess_image(image_path) input_array = input_tensor.numpy() # Run inference input_name = session.get_inputs()[0].name output = session.run(None, {input_name: input_array})[0] # Postprocess output_tensor = torch.from_numpy(output) mask = postprocess_output(output_tensor) return mask def save_mask(mask, output_path): """Save segmentation mask as image.""" # Convert to 0-255 range mask_image = (mask * 255).astype(np.uint8) cv2.imwrite(output_path, mask_image) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Run inference on card segmentation model') parser.add_argument('--model', type=str, required=True, help='Path to model file') parser.add_argument('--image', type=str, required=True, help='Path to input image') parser.add_argument('--output', type=str, default='output_mask.png', help='Output mask path') parser.add_argument('--format', type=str, choices=['pytorch', 'onnx'], default='onnx', help='Model format') args = parser.parse_args() # Run inference if args.format == 'pytorch': mask = inference_pytorch(args.model, args.image) else: mask = inference_onnx(args.model, args.image) # Save result save_mask(mask, args.output) print(f"Segmentation mask saved to: {args.output}")