# Run the setup.sh install script before running this app. import os # os.system("bash setup.sh") import gradio as gr import cv2 from gradio_webrtc import WebRTC import time import threading import numpy as np from time import sleep import spaces from twilio.rest import Client try: import mmcv from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples from mmpose.utils import adapt_mmdet_pipeline except (ImportError, ModuleNotFoundError): os.system("pip uninstall -y mmpose mmdet mmcv mmengine mmpretrain") os.system("pip install -U openmim") os.system("mim install mmengine mmcv==2.1.0 mmdet==3.3.0 mmpretrain==1.2.0") import mmcv from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples from mmpose.utils import adapt_mmdet_pipeline import hashlib try: from mmdet.apis import inference_detector, init_detector has_mmdet = True except (ImportError, ModuleNotFoundError): has_mmdet = False DET_CFG = "demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py" DET_WEIGHTS = "https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth" POSE_CFG = "configs/body_2d_keypoint/topdown_probmap/coco/td-pm_ProbPose-small_8xb64-210e_coco-256x192.py" POSE_WEIGHTS = "models/ProbPose-s.pth" DEVICE = 'cuda:0' # DEVICE = 'cpu' # Global variables for models print("Initializing MMDetection detector...") det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE) det_model.cfg = adapt_mmdet_pipeline(det_model.cfg) print("Detector initialized successfully!") print("Initializing MMPose estimator...") pose_model = init_pose_estimator( POSE_CFG, POSE_WEIGHTS, device=DEVICE, cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True))) ) print("Pose estimator initialized successfully!") pose_model.cfg.visualizer.radius = 4 pose_model.cfg.visualizer.alpha = 0.8 pose_model.cfg.visualizer.line_width = 2 visualizer = VISUALIZERS.build(pose_model.cfg.visualizer) visualizer.set_dataset_meta( pose_model.dataset_meta, skeleton_style='mmpose' ) @spaces.GPU def process_frame(frame, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3): """Process a single frame with pose estimation""" global det_model, pose_model, visualizer processing_start = time.time() # Mirror the frame frame = frame[:, ::-1, :] # Flip horizontally for webcam mirroring det_result = inference_detector(det_model, frame) pred_instance = det_result.pred_instances.cpu().numpy() bboxes = np.concatenate( (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > bbox_thr)] # Sort bboxes by confidence score (column 4) in descending order if len(bboxes) > 0: order = np.argsort(bboxes[:, 4])[::-1] bboxes = bboxes[order[0], :4].reshape((1, -1)) else: # No person detected, return original frame return frame visualizer.set_image(frame) # predict keypoints pose_start = time.time() pose_results = inference_topdown(pose_model, frame, bboxes) data_samples = merge_data_samples(pose_results) # Visualize results visualization_start = time.time() visualizer.add_datasample( 'result', frame, data_sample=data_samples, draw_gt=False, draw_heatmap=False, draw_bbox=True, show_kpt_idx=False, show=False, kpt_thr=kpt_thr) stop_time = time.time() return visualizer.get_image() # WebRTC configuration for webcam streaming client = Client( os.getenv("TWILIO_ACCOUNT_SID"), os.getenv("TWILIO_AUTH_TOKEN") ) token = client.tokens.create() # includes token.iceServers rtc_configuration = {"iceServers": token.ice_servers} webcam_constraints = { "video": { "width": {"exact": 320}, "height": {"exact": 240}, "sampleRate": {"ideal": 2, "max": 5} } } class AsyncFrameProcessor: """ Asynchronous frame processor that handles real-time video stream processing. Maintains single-slot input and output queues to process only the latest frame, preventing queue buildup and ensuring real-time performance. """ def __init__(self, processing_delay=0.5, startup_delay=0.0): """ Initialize the async frame processor. Args: processing_delay (float): Simulated processing time in seconds startup_delay (float): Delay before processing starts """ self.processing_delay = processing_delay self.startup_delay = startup_delay self.first_call_time = None self.frame_counter = 0 # Thread-safe single-slot queues self.input_lock = threading.Lock() self.output_lock = threading.Lock() self.latest_input_frame = None self.latest_output_frame = None # Threading components self.processing_thread = None self.stop_event = threading.Event() self.new_frame_signal = threading.Event() # Start background processing self._start_processing_thread() def _start_processing_thread(self): """Start the background processing thread""" if self.processing_thread is None or not self.processing_thread.is_alive(): self.stop_event.clear() self.processing_thread = threading.Thread(target=self._processing_worker, daemon=True) self.processing_thread.start() def _processing_worker(self): """Background thread that processes the latest frame""" while not self.stop_event.is_set(): # Wait for a new frame to be available if self.new_frame_signal.wait(timeout=1.0): self.new_frame_signal.clear() print("New frame received, starting processing...") # Get the latest input frame with self.input_lock: if self.latest_input_frame is not None: frame_to_process = self.latest_input_frame.copy() frame_number = self.frame_counter process_unique_hash = hashlib.md5(frame_to_process.tobytes()).hexdigest() # print(f"Processing unique hash: {process_unique_hash}") else: continue print(f"Processing frame number: {frame_number}") # Process the frame using the global models processed_frame = process_frame(frame_to_process) # Write frame number in the top left corner processed_frame = cv2.putText( processed_frame, "{:d}".format(frame_number), [50, 50], fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), thickness=2, ) print(f"Frame {frame_number} processed.") # Store the processed result with self.output_lock: self.latest_output_frame = processed_frame def process(self, frame): """ Main processing function called by Gradio stream. Stores incoming frame and returns latest processed result. """ current_time = time.time() if self.first_call_time is None: self.first_call_time = current_time # Store the new frame in the input slot (replacing any existing frame) with self.input_lock: self.latest_input_frame = frame.copy() self.frame_counter += 1 input_unique_hash = hashlib.md5(frame.tobytes()).hexdigest() # print(f"Input unique hash: {input_unique_hash}") # Signal that a new frame is available for processing self.new_frame_signal.set() # Return the latest processed output, or original frame if no processing done yet with self.output_lock: if self.latest_output_frame is not None: output_unique_hash = hashlib.md5(self.latest_output_frame.tobytes()).hexdigest() # print(f"Output unique hash: {output_unique_hash}") return self.latest_output_frame else: # Add indicator that this is unprocessed temp_frame = frame.copy() cv2.putText( temp_frame, f"Waiting... {self.frame_counter}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), # Red for unprocessed frames 2, ) return temp_frame def stop(self): """Stop the processing thread""" self.stop_event.set() if self.processing_thread and self.processing_thread.is_alive(): self.processing_thread.join(timeout=2.0) # CSS for styling the Gradio interface css = """.my-group {max-width: 600px !important; max-height: 600 !important;} .my-column {display: flex !important; justify-content: center !important; align-items: center !important};""" # Initialize the asynchronous frame processor frame_processor = AsyncFrameProcessor(processing_delay=0.5) # Create Gradio interface with gr.Blocks(css=css) as demo: gr.HTML( """

ProbPose Webcam Demo (CVPR 2025)

""" ) gr.HTML( """

See https://MiraPurkrabek.github.io/ProbPose/ for details.

""" ) # with gr.Column(elem_classes=["my-column"]): # with gr.Group(elem_classes=["my-group"]): # webcam_stream = WebRTC( # label="Webcam Stream", # rtc_configuration=rtc_configuration, # track_constraints=webcam_constraints, # mirror_webcam=True, # ) # # Stream processing: connects webcam input to frame processor # webcam_stream.stream( # fn=frame_processor.process, # inputs=[webcam_stream], # outputs=[webcam_stream], # time_limit=120, # Limit processing time to 120 seconds # ) with gr.Row(): with gr.Column(): input_img = gr.Image(sources=["webcam"], type="numpy") with gr.Column(): output_img = gr.Image(streaming=True) dep = input_img.stream(frame_processor.process, [input_img], [output_img], time_limit=30, concurrency_limit=30) if __name__ == "__main__": demo.launch( # server_name="0.0.0.0", # server_port=17860, # share=True )