# pip install -q rfdetr==1.2.1 supervision==0.26.1
# RF-DETR video processing for threat detection.
# Inference time depends on frame resolution (e.g., ~50 ms/frame on GPU for 640×640).
import os
import time

import cv2
import numpy as np
import requests
import supervision as sv
import torch
from PIL import Image
from rfdetr import RFDETRNano
from tqdm import tqdm

THREAT_CLASSES = {
    1: "Gun",
    2: "Explosive",
    3: "Grenade",
    4: "Knife",
}

# Enable GPU if available
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # print(f"CUDA Version: {torch.version.cuda}")
    # print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    # Let cuDNN pick the fastest kernels for the fixed input size
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
else:
    print("CUDA not available, using CPU")

# Configuration
INPUT_VIDEO = "test_video.mp4"
base, ext = os.path.splitext(INPUT_VIDEO)
OUTPUT_VIDEO = f"{base}_detr{ext}"
THRESHOLD = 0.5
BATCH_SIZE = 32

# Report GPU memory and the (fixed) batch size; lower BATCH_SIZE if you hit out-of-memory errors
if torch.cuda.is_available():
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU memory: {gpu_memory_gb:.1f} GB, using batch size: {BATCH_SIZE}")

# Download weights if they are not already present
weights_url = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
weights_filename = "checkpoint_best_total.pth"

if not os.path.exists(weights_filename):
    print(f"Downloading weights from {weights_url}")
    response = requests.get(weights_url, stream=True)
    response.raise_for_status()
    with open(weights_filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print("Download complete.")

print("Loading model...")
model = RFDETRNano(resolution=640, pretrain_weights=weights_filename)
model.optimize_for_inference()

# Setup annotators
color = sv.ColorPalette.from_hex([
    "#1E90FF", "#32CD32", "#FF0000", "#FF8C00"
])
bbox_annotator = sv.BoxAnnotator(color=color, thickness=3)
label_annotator = sv.LabelAnnotator(
    color=color,
    text_color=sv.Color.BLACK,
    text_scale=1.0,
    text_thickness=2,
    smart_position=True,
)


def process_frame_batch(frames):
    """Run detection on a buffer of frames and return annotated frames plus detections."""
    # Convert all frames from BGR (OpenCV) to RGB PIL images
    pil_images = []
    for frame in frames:
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_images.append(Image.fromarray(rgb_frame))

    # Run inference on each buffered frame (model.predict is called once per image)
    batch_detections = []
    for pil_image in pil_images:
        detections = model.predict(pil_image, threshold=THRESHOLD)
        batch_detections.append(detections)

    # Annotate all images in the batch
    annotated_frames = []
    for pil_image, detections in zip(pil_images, batch_detections):
        # Build "class confidence" labels
        labels = []
        for class_id, confidence in zip(detections.class_id, detections.confidence):
            class_name = THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
            labels.append(f"{class_name} {confidence:.2f}")

        # Draw boxes and labels
        annotated_pil = pil_image.copy()
        annotated_pil = bbox_annotator.annotate(annotated_pil, detections)
        annotated_pil = label_annotator.annotate(annotated_pil, detections, labels)

        # Convert back to BGR for the video writer
        annotated_frames.append(cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR))

    return annotated_frames, batch_detections
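
# Optional sanity check before processing the full video: a minimal sketch that runs a
# single-image prediction to confirm the weights loaded correctly. "sample.jpg" is a
# hypothetical local file, not part of the pipeline; uncomment to use.
# sample_image = Image.open("sample.jpg").convert("RGB")
# sample_detections = model.predict(sample_image, threshold=THRESHOLD)
# print(f"Sanity check: {len(sample_detections)} detections in sample.jpg")
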
# Open video
cap = cv2.VideoCapture(INPUT_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video file {INPUT_VIDEO}")
    raise SystemExit(1)

# Get video properties
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

print(f"Video: {width}x{height}, {fps} FPS, {total_frames} frames")
print(f"Processing in batches of {BATCH_SIZE} frames")

# Setup video writer
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))

# Batch processing
print("Processing video with batch inference...")
frame_buffer = []
total_detections = 0
processed_frames = 0
processing_times = []

with tqdm(total=total_frames, desc="Batch processing") as pbar:
    while True:
        ret, frame = cap.read()

        if not ret:
            # Process any remaining frames in the buffer
            if frame_buffer:
                start_time = time.time()
                annotated_frames, batch_detections = process_frame_batch(frame_buffer)
                processing_time = time.time() - start_time
                processing_times.append(processing_time)

                # Write remaining frames
                for annotated_frame, detections in zip(annotated_frames, batch_detections):
                    out.write(annotated_frame)
                    total_detections += len(detections)

                processed_frames += len(frame_buffer)
                pbar.update(len(frame_buffer))
            break

        # Add frame to buffer
        frame_buffer.append(frame)

        # Process when buffer is full
        if len(frame_buffer) >= BATCH_SIZE:
            start_time = time.time()
            annotated_frames, batch_detections = process_frame_batch(frame_buffer)
            processing_time = time.time() - start_time
            processing_times.append(processing_time)

            # Write annotated frames and count detections
            batch_threats = 0
            for annotated_frame, detections in zip(annotated_frames, batch_detections):
                out.write(annotated_frame)
                batch_threats += len(detections)
                total_detections += len(detections)

            processed_frames += len(frame_buffer)

            # Update progress bar
            batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
            pbar.set_postfix({
                "Batch FPS": f"{batch_fps:.1f}",
                "Threats": batch_threats,
                "Total": total_detections,
            })
            pbar.update(len(frame_buffer))

            # Clear buffer
            frame_buffer = []

            # Clear GPU cache every 10 batches
            if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
                torch.cuda.empty_cache()

# Cleanup
cap.release()
out.release()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Performance summary
total_time = sum(processing_times)
avg_fps = processed_frames / total_time if total_time > 0 else 0
speedup = avg_fps / fps if fps > 0 else 0

print(f"Output: {OUTPUT_VIDEO}")
print("Stats:")
print(f"  • Processed: {processed_frames} frames")
print(f"  • Detections: {total_detections}")
print(f"  • Batch size: {BATCH_SIZE}")
print(f"  • Average speed: {avg_fps:.1f} FPS")
print(f"  • Speedup: {speedup:.1f}x real-time")
print(f"  • Processing time: {total_time:.1f}s")
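
# Optional post-processing step: the 'mp4v' FourCC produces MPEG-4 Part 2 video, which
# many browsers and notebook players cannot decode. A minimal sketch of an H.264
# re-encode, assuming ffmpeg is installed and on PATH (not required by the script
# above); uncomment to use.
# import subprocess
# subprocess.run([
#     "ffmpeg", "-y", "-i", OUTPUT_VIDEO,
#     "-c:v", "libx264", "-pix_fmt", "yuv420p",
#     f"{base}_detr_h264.mp4",
# ], check=True)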