Spaces:
Sleeping
Sleeping
# Run the setup.sh install script before running this app. | |
import os | |
# os.system("bash setup.sh") | |
import gradio as gr | |
import cv2 | |
from gradio_webrtc import WebRTC | |
import time | |
import threading | |
import numpy as np | |
from time import sleep | |
import spaces | |
from twilio.rest import Client | |
try: | |
import mmcv | |
from mmpose.apis import inference_topdown | |
from mmpose.apis import init_model as init_pose_estimator | |
from mmpose.evaluation.functional import nms | |
from mmpose.registry import VISUALIZERS | |
from mmpose.structures import merge_data_samples | |
from mmpose.utils import adapt_mmdet_pipeline | |
except (ImportError, ModuleNotFoundError): | |
os.system("pip uninstall -y mmpose mmdet mmcv mmengine mmpretrain") | |
os.system("pip install -U openmim") | |
os.system("mim install mmengine mmcv==2.1.0 mmdet==3.3.0 mmpretrain==1.2.0") | |
import mmcv | |
from mmpose.apis import inference_topdown | |
from mmpose.apis import init_model as init_pose_estimator | |
from mmpose.evaluation.functional import nms | |
from mmpose.registry import VISUALIZERS | |
from mmpose.structures import merge_data_samples | |
from mmpose.utils import adapt_mmdet_pipeline | |
import hashlib | |
try: | |
from mmdet.apis import inference_detector, init_detector | |
has_mmdet = True | |
except (ImportError, ModuleNotFoundError): | |
has_mmdet = False | |
# Config and checkpoint locations for the person detector and pose estimator.
DET_CFG = "demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
DET_WEIGHTS = "https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
POSE_CFG = "configs/body_2d_keypoint/topdown_probmap/coco/td-pm_ProbPose-small_8xb64-210e_coco-256x192.py"
POSE_WEIGHTS = "models/ProbPose-s.pth"

# Device both models are loaded onto.
DEVICE = 'cuda:0'
# DEVICE = 'cpu'

# The mmdet import is optional and only records its outcome in ``has_mmdet``.
# Everything below needs the detector, so fail here with an actionable
# message instead of a bare NameError on ``init_detector``.
if not has_mmdet:
    raise RuntimeError(
        "mmdet is required for person detection but could not be imported; "
        "install it (e.g. `mim install mmdet==3.3.0`) and restart the app."
    )

# Global variables for models
print("Initializing MMDetection detector...")
det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE)
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
print("Detector initialized successfully!")

print("Initializing MMPose estimator...")
pose_model = init_pose_estimator(
    POSE_CFG,
    POSE_WEIGHTS,
    device=DEVICE,
    # Override the test config so the model is asked to output heatmaps.
    cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True))),
)
print("Pose estimator initialized successfully!")

# Drawing parameters for the shared visualizer, set before it is built.
pose_model.cfg.visualizer.radius = 4
pose_model.cfg.visualizer.alpha = 0.8
pose_model.cfg.visualizer.line_width = 2
visualizer = VISUALIZERS.build(pose_model.cfg.visualizer)
visualizer.set_dataset_meta(
    pose_model.dataset_meta, skeleton_style='mmpose'
)
def process_frame(frame, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3):
    """Mirror a frame, detect the top-scoring person and draw their pose.

    Uses the module-level ``det_model``, ``pose_model`` and ``visualizer``.

    Args:
        frame: Image array indexed as (row, column, channel).
        bbox_thr: Detections with a score at or below this value are dropped.
        nms_thr: Currently unused; kept so existing callers keep working.
        kpt_thr: Keypoint threshold forwarded to the visualizer.

    Returns:
        The image rendered by the visualizer when a person is detected,
        otherwise the mirrored input frame.
    """
    # Flip horizontally for webcam mirroring. The flip alone is a
    # negative-stride view, which OpenCV's in-place drawing functions reject,
    # so materialize it as a contiguous array before anyone draws on it.
    frame = np.ascontiguousarray(frame[:, ::-1, :])

    det_result = inference_detector(det_model, frame)
    pred_instance = det_result.pred_instances.cpu().numpy()

    # Rows are the box coordinates followed by the score in column 4.
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    # Keep label 0 only (presumably the person class of the detector config).
    keep = np.logical_and(pred_instance.labels == 0,
                          pred_instance.scores > bbox_thr)
    bboxes = bboxes[keep]

    if len(bboxes) == 0:
        # No person detected, return the mirrored frame unchanged.
        return frame

    # Sort by confidence score (column 4) in descending order and keep only
    # the most confident box, dropping the score column.
    order = np.argsort(bboxes[:, 4])[::-1]
    bboxes = bboxes[order[0], :4].reshape((1, -1))

    visualizer.set_image(frame)

    # Predict keypoints for the selected box.
    pose_results = inference_topdown(pose_model, frame, bboxes)
    data_samples = merge_data_samples(pose_results)

    # Draw the prediction onto the visualizer's canvas.
    visualizer.add_datasample(
        'result',
        frame,
        data_sample=data_samples,
        draw_gt=False,
        draw_heatmap=False,
        draw_bbox=True,
        show_kpt_idx=False,
        show=False,
        kpt_thr=kpt_thr)

    return visualizer.get_image()
# WebRTC configuration for webcam streaming.
# ICE servers are requested from Twilio only when both credentials are set.
# Without them the Twilio client cannot be constructed, which would abort
# start-up, so fall back to no explicit RTC configuration instead.
_twilio_account_sid = os.getenv("TWILIO_ACCOUNT_SID")
_twilio_auth_token = os.getenv("TWILIO_AUTH_TOKEN")
if _twilio_account_sid and _twilio_auth_token:
    client = Client(_twilio_account_sid, _twilio_auth_token)
    token = client.tokens.create()  # includes token.ice_servers
    rtc_configuration = {"iceServers": token.ice_servers}
else:
    print("Twilio credentials not set; continuing without ICE servers.")
    client = None
    token = None
    rtc_configuration = None

# Constraints requested from the browser for the webcam video track.
# "frameRate" is the video-track constraint; "sampleRate" applies to audio
# tracks only and has no effect on video.
webcam_constraints = {
    "video": {
        "width": {"exact": 320},
        "height": {"exact": 240},
        "frameRate": {"ideal": 2, "max": 5},
    }
}
class AsyncFrameProcessor:
    """
    Asynchronous frame processor that handles real-time video stream processing.

    Keeps one slot for the newest input frame and one for the newest processed
    frame. A background thread processes whatever is in the input slot when it
    is signalled, so frames arriving while it is busy overwrite each other
    instead of queueing up.
    """

    def __init__(self, processing_delay=0.5, startup_delay=0.0):
        """
        Initialize the async frame processor and start its worker thread.

        Args:
            processing_delay (float): Stored on the instance; not read by any
                method of this class.
            startup_delay (float): Stored on the instance; not read by any
                method of this class.
        """
        self.processing_delay = processing_delay
        self.startup_delay = startup_delay
        self.first_call_time = None
        self.frame_counter = 0

        # Single-slot buffers, each guarded by its own lock
        self.input_lock = threading.Lock()
        self.output_lock = threading.Lock()
        self.latest_input_frame = None
        self.latest_output_frame = None

        # Threading components
        self.processing_thread = None
        self.stop_event = threading.Event()
        self.new_frame_signal = threading.Event()

        # Start background processing
        self._start_processing_thread()

    def _start_processing_thread(self):
        """Start the background processing thread unless one is already alive."""
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_event.clear()
            self.processing_thread = threading.Thread(
                target=self._processing_worker, daemon=True)
            self.processing_thread.start()

    def _processing_worker(self):
        """Background loop: process the latest input frame whenever signalled."""
        while not self.stop_event.is_set():
            # Time out once a second so a stop request is noticed even when
            # no frames arrive.
            if not self.new_frame_signal.wait(timeout=1.0):
                continue
            self.new_frame_signal.clear()
            print("New frame received, starting processing...")

            # Snapshot the latest input frame and its number under the lock.
            with self.input_lock:
                if self.latest_input_frame is None:
                    continue
                frame_to_process = self.latest_input_frame.copy()
                frame_number = self.frame_counter

            print(f"Processing frame number: {frame_number}")
            try:
                # Process the frame using the global models
                processed_frame = process_frame(frame_to_process)
                # putText draws in place and rejects non-contiguous arrays
                # (e.g. a flipped view), so normalise the layout first.
                processed_frame = np.ascontiguousarray(processed_frame)
                # Write frame number in the top left corner
                processed_frame = cv2.putText(
                    processed_frame,
                    f"{frame_number:d}",
                    (50, 50),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=1,
                    color=(0, 0, 255),
                    thickness=2,
                )
            except Exception as exc:
                # An uncaught error would end this thread and freeze the
                # output on the last good frame; report it and keep serving.
                print(f"Frame {frame_number} failed to process: {exc!r}")
                continue
            print(f"Frame {frame_number} processed.")

            # Store the processed result
            with self.output_lock:
                self.latest_output_frame = processed_frame

    def process(self, frame):
        """
        Main processing function called by Gradio stream.

        Stores the incoming frame for the worker and returns the latest
        processed result without waiting for the new frame to be processed.

        Args:
            frame: Image array from the stream, or None when no frame is
                available.

        Returns:
            The latest processed frame if one exists. Otherwise a copy of
            ``frame`` labelled "Waiting...", or None when ``frame`` is None
            and nothing has been processed yet.
        """
        if frame is None:
            # Nothing new to enqueue; keep showing the last result (if any).
            with self.output_lock:
                return self.latest_output_frame

        current_time = time.time()
        if self.first_call_time is None:
            self.first_call_time = current_time

        # Store the new frame in the input slot (replacing any existing frame)
        with self.input_lock:
            self.latest_input_frame = frame.copy()
            self.frame_counter += 1
            frame_number = self.frame_counter

        # Signal that a new frame is available for processing
        self.new_frame_signal.set()

        # Return the latest processed output if there is one
        with self.output_lock:
            if self.latest_output_frame is not None:
                return self.latest_output_frame

        # No processing done yet: return the input marked as unprocessed
        temp_frame = frame.copy()
        cv2.putText(
            temp_frame,
            f"Waiting... {frame_number}",
            (50, 50),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 0, 0),  # Red (assuming RGB channel order) for unprocessed frames
            2,
        )
        return temp_frame

    def stop(self):
        """Stop the processing thread, waiting up to two seconds for it to exit."""
        self.stop_event.set()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=2.0)
# CSS for styling the Gradio interface.
# Lengths need a unit ("600" alone is an invalid max-height), and each rule's
# final semicolon belongs inside its closing brace.
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# Single shared processor instance; its worker thread starts on construction.
frame_processor = AsyncFrameProcessor(processing_delay=0.5)

# Build the Gradio page: title, project link, and webcam-in / image-out row.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        ProbPose Webcam Demo (CVPR 2025)
        </h1>
        """
    )
    gr.HTML(
        """
        <h3 style='text-align: center'>
        See <a href="https://MiraPurkrabek.github.io/ProbPose/" target="_blank">https://MiraPurkrabek.github.io/ProbPose/</a> for details.
        </h3>
        """
    )

    # Alternative WebRTC-based streaming path, currently disabled.
    # with gr.Column(elem_classes=["my-column"]):
    #     with gr.Group(elem_classes=["my-group"]):
    #         webcam_stream = WebRTC(
    #             label="Webcam Stream",
    #             rtc_configuration=rtc_configuration,
    #             track_constraints=webcam_constraints,
    #             mirror_webcam=True,
    #         )
    #         # Stream processing: connects webcam input to frame processor
    #         webcam_stream.stream(
    #             fn=frame_processor.process,
    #             inputs=[webcam_stream],
    #             outputs=[webcam_stream],
    #             time_limit=120,  # Limit processing time to 120 seconds
    #         )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(sources=["webcam"], type="numpy")
        with gr.Column():
            output_img = gr.Image(streaming=True)
        # Feed each webcam frame to the processor and show what it returns.
        dep = input_img.stream(
            fn=frame_processor.process,
            inputs=[input_img],
            outputs=[output_img],
            time_limit=30,
            concurrency_limit=30,
        )

if __name__ == "__main__":
    # Launch with Gradio defaults; the overrides below are kept for reference.
    demo.launch(
        # server_name="0.0.0.0",
        # server_port=17860,
        # share=True
    )