|
|
|
import argparse |
|
import torch |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
import json |
|
from PIL import Image |
|
|
|
def _select_device(model):
    """Move *model* to the best available accelerator and report the choice.

    Returns:
        A ``(device, use_half)`` tuple. ``device`` is ``'cuda'``, ``'mps'`` or
        ``None`` (CPU; the model is not moved). ``use_half`` is True only for
        CUDA.
    """
    if torch.cuda.is_available():
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
        model.to('cuda')
        return 'cuda', True
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        print("Using Apple Silicon MPS")
        model.to('mps')
        return 'mps', False
    print("Using CPU")
    return None, False


def _polygon_extent(polygon_points, img_width, img_height):
    """Return the polygon's axis-aligned extent relative to the image size.

    Args:
        polygon_points: List of ``[x, y]`` pairs; may be empty.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        ``(width_fraction, height_fraction)``: the extent divided by the image
        dimension (a 0-1 style fraction, not multiplied by 100). An empty
        polygon yields ``(0.0, 0.0)`` instead of raising from ``min()``.
    """
    if not polygon_points:
        return 0.0, 0.0
    x_coords = [point[0] for point in polygon_points]
    y_coords = [point[1] for point in polygon_points]
    return (
        (max(x_coords) - min(x_coords)) / img_width,
        (max(y_coords) - min(y_coords)) / img_height,
    )


def _draw_detection(image, box_xyxy, confidence, polygon_points):
    """Draw one detection (box, label, translucent mask) and return the image.

    The rectangle and label are drawn in place on *image*; the mask overlay
    produces a new blended array, which is what gets returned.
    """
    x1, y1, x2, y2 = box_xyxy
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
    # Keep the label on-canvas when the box touches the top of the image.
    label_y = max(y1 - 10, 15)
    cv2.putText(image, f'Person: {confidence:.2f}', (x1, label_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    if not polygon_points:
        # Nothing to fill; an empty point array is not a valid polygon.
        return image

    color_mask = np.zeros_like(image, dtype=np.uint8)
    mask_points = np.array(polygon_points, dtype=np.int32)
    cv2.fillPoly(color_mask, [mask_points], (0, 0, 255))
    return cv2.addWeighted(image, 1.0, color_mask, 0.5, 0)


def main():
    """Run person segmentation on one image and save the results.

    Command-line arguments:
        --model   Model weights path (default ``yolo12l-person-seg.pt``).
        --image   Input image path (required).
        --output  Path for the annotated visualization image.
        --json    Path for the JSON file with per-detection data.
        --conf    Confidence threshold passed to the model.

    Side effects:
        Prints progress to stdout. When at least one person is detected,
        writes the annotated image to ``--output`` and a JSON document
        ``{"person_count": int, "detections": [...]}`` to ``--json``. When the
        image cannot be read or nobody is detected, prints a message and
        returns without writing any file.
    """
    parser = argparse.ArgumentParser(description='Run person segmentation with YOLO12l-seg model')
    parser.add_argument('--model', type=str, default='yolo12l-person-seg.pt', help='Model path')
    parser.add_argument('--image', type=str, required=True, help='Image path for inference')
    parser.add_argument('--output', type=str, default='output.jpg', help='Output visualization image path')
    parser.add_argument('--json', type=str, default='detections.json', help='JSON output file for detection data')
    parser.add_argument('--conf', type=float, default=0.5, help='Confidence threshold')
    args = parser.parse_args()

    model = YOLO(args.model)
    device, use_half = _select_device(model)

    # Only the dimensions are needed; the context manager closes the file
    # handle that Image.open would otherwise leave open.
    try:
        with Image.open(args.image) as img:
            img_width, img_height = img.size
    except Exception as e:  # CLI boundary: report any open failure and stop
        print(f"Error opening image: {e}")
        return
    print(f"Image dimensions: {img_width}x{img_height}")

    # cv2.imread signals failure by returning None rather than raising, so
    # check before spending time on inference.
    visualization_img = cv2.imread(args.image)
    if visualization_img is None:
        print(f"Error opening image: OpenCV could not read {args.image}")
        return

    # Same calls as before: device only when an accelerator was selected,
    # half precision only on CUDA.
    predict_kwargs = {"classes": 0, "conf": args.conf}
    if device is not None:
        predict_kwargs["device"] = device
    if use_half:
        predict_kwargs["half"] = True
    results = model(args.image, **predict_kwargs)

    detections = []
    person_count = 0  # initialised so an empty result list cannot raise NameError

    for result in results:
        masks = result.masks
        boxes = result.boxes

        if boxes is None or len(boxes) == 0:
            continue
        person_count += len(boxes)

        if masks is None:
            # Boxes without masks: nothing to record per detection, but say so
            # because person_count will exceed len(detections).
            print("Warning: model returned boxes but no segmentation masks; "
                  "no per-detection data recorded")
            continue

        for mask, box in zip(masks.xy, boxes):
            confidence = float(box.conf[0])
            x1, y1, x2, y2 = map(int, box.xyxy[0])

            polygon_points = mask.tolist()
            width_pct, height_pct = _polygon_extent(
                polygon_points, img_width, img_height)

            detections.append({
                # Running index: equals the per-result index for a single
                # result and stays unique if several results are returned.
                "id": len(detections),
                "confidence": confidence,
                "box": [x1, y1, x2, y2],
                "points": polygon_points,
                "width_pct": width_pct,
                "height_pct": height_pct,
            })

            visualization_img = _draw_detection(
                visualization_img, (x1, y1, x2, y2), confidence, polygon_points)

    if person_count == 0:
        print("No people detected in the image")
        return
    print(f"Detected {person_count} people")

    # cv2.imwrite returns False on failure; do not claim success blindly.
    if cv2.imwrite(args.output, visualization_img):
        print(f"Visualization saved to {args.output}")
    else:
        print(f"Error: could not write visualization to {args.output}")

    with open(args.json, 'w', encoding='utf-8') as f:
        json.dump({
            "person_count": person_count,
            "detections": detections,
        }, f, indent=4)
    print(f"Detection data saved to {args.json}")
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()