ProbPose-demo / app.py
Miroslav Purkrabek
init models first
4241c11
# Run the setup.sh install script before running this app.
import os
# os.system("bash setup.sh")
import gradio as gr
import cv2
from gradio_webrtc import WebRTC
import time
import threading
import numpy as np
from time import sleep
import spaces
from twilio.rest import Client
# Import MMPose. If any of these imports fail, try to rebuild the OpenMMLab
# stack with openmim and then import again.
# NOTE(review): the fallback uninstalls mmpose, but the `mim install` line
# below does not reinstall it, so the second round of `from mmpose...` imports
# presumably only succeeds if mmpose is provided some other way (e.g. by
# setup.sh) — confirm.
try:
    import mmcv
    from mmpose.apis import inference_topdown
    from mmpose.apis import init_model as init_pose_estimator
    from mmpose.evaluation.functional import nms
    from mmpose.registry import VISUALIZERS
    from mmpose.structures import merge_data_samples
    from mmpose.utils import adapt_mmdet_pipeline
except (ImportError, ModuleNotFoundError):
    # Shell commands run at import time; their exit codes are not checked.
    os.system("pip uninstall -y mmpose mmdet mmcv mmengine mmpretrain")
    os.system("pip install -U openmim")
    os.system("mim install mmengine mmcv==2.1.0 mmdet==3.3.0 mmpretrain==1.2.0")
    import mmcv
    from mmpose.apis import inference_topdown
    from mmpose.apis import init_model as init_pose_estimator
    from mmpose.evaluation.functional import nms
    from mmpose.registry import VISUALIZERS
    from mmpose.structures import merge_data_samples
    from mmpose.utils import adapt_mmdet_pipeline
import hashlib
# Record whether mmdet imported; model initialization further down needs
# init_detector and process_frame needs inference_detector.
try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False
# Person-detector and ProbPose configs / checkpoints. Local paths are resolved
# relative to the current working directory.
DET_CFG = "demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
DET_WEIGHTS = "https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
POSE_CFG = "configs/body_2d_keypoint/topdown_probmap/coco/td-pm_ProbPose-small_8xb64-210e_coco-256x192.py"
POSE_WEIGHTS = "models/ProbPose-s.pth"
DEVICE = 'cuda:0'
# DEVICE = 'cpu'

# Fail early with an actionable message: without mmdet, `init_detector` below
# is undefined and the import would otherwise die with a bare NameError.
if not has_mmdet:
    raise ImportError(
        "mmdet is required for person detection but could not be imported; "
        "install it (e.g. via setup.sh) before running this app."
    )

# Models are built once at import time and used as module-level globals.
print("Initializing MMDetection detector...")
det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE)
# mmpose helper that adapts the detector config's pipeline for use with mmpose.
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
print("Detector initialized successfully!")

print("Initializing MMPose estimator...")
pose_model = init_pose_estimator(
    POSE_CFG,
    POSE_WEIGHTS,
    device=DEVICE,
    # Sets test_cfg.output_heatmaps; presumably makes the model also return
    # heatmaps at inference time (they are not drawn: draw_heatmap=False).
    cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True))),
)
print("Pose estimator initialized successfully!")

# Drawing style for keypoints / skeleton, then build the shared visualizer.
pose_model.cfg.visualizer.radius = 4
pose_model.cfg.visualizer.alpha = 0.8
pose_model.cfg.visualizer.line_width = 2
visualizer = VISUALIZERS.build(pose_model.cfg.visualizer)
visualizer.set_dataset_meta(
    pose_model.dataset_meta, skeleton_style='mmpose'
)
@spaces.GPU
def process_frame(frame, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3):
    """Run person detection + ProbPose on one frame and return a visualization.

    The frame is mirrored horizontally first (webcam "selfie" view). Only the
    single highest-scoring detection that passes the filters is sent to the
    pose estimator.

    Args:
        frame: Rank-3 image array (assumed H x W x C — confirm with caller).
        bbox_thr: Minimum detector score for a box to be kept.
        nms_thr: Accepted for API compatibility but unused: only one box is
            kept, so non-maximum suppression would have no effect.
        kpt_thr: Keypoint score threshold used when drawing.

    Returns:
        The visualizer's rendered image, or the mirrored input frame (as a
        C-contiguous array) when no box passes the filters.
    """
    # Mirror for webcam display. The reversed slice is a negative-stride view;
    # make it contiguous so callers can draw on it in place (cv2.putText
    # rejects negative-stride arrays).
    frame = np.ascontiguousarray(frame[:, ::-1, :])

    det_result = inference_detector(det_model, frame)
    pred_instance = det_result.pred_instances.cpu().numpy()
    # Rows are [x1, y1, x2, y2, score].
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    # Keep label-0 detections (the person class of the person-only RTMDet
    # config) that score above the threshold.
    keep = np.logical_and(pred_instance.labels == 0,
                          pred_instance.scores > bbox_thr)
    bboxes = bboxes[keep]
    if len(bboxes) == 0:
        # No person detected: return the (mirrored) frame unchanged.
        return frame

    # Keep only the most confident box and drop the score column -> (1, 4).
    best = int(np.argmax(bboxes[:, 4]))
    bboxes = bboxes[best:best + 1, :4]

    visualizer.set_image(frame)
    # Predict keypoints inside the selected box.
    pose_results = inference_topdown(pose_model, frame, bboxes)
    data_samples = merge_data_samples(pose_results)

    visualizer.add_datasample(
        'result',
        frame,
        data_sample=data_samples,
        draw_gt=False,
        draw_heatmap=False,
        draw_bbox=True,
        show_kpt_idx=False,
        show=False,
        kpt_thr=kpt_thr)
    return visualizer.get_image()
# WebRTC configuration for webcam streaming.
# ICE servers come from a Twilio token created with credentials read from the
# environment. When the credentials are missing, the app still starts and
# rtc_configuration is None (presumably the Twilio client would otherwise
# raise at import time).
_twilio_account_sid = os.getenv("TWILIO_ACCOUNT_SID")
_twilio_auth_token = os.getenv("TWILIO_AUTH_TOKEN")
if _twilio_account_sid and _twilio_auth_token:
    client = Client(_twilio_account_sid, _twilio_auth_token)
    token = client.tokens.create()  # includes token.ice_servers
    rtc_configuration = {"iceServers": token.ice_servers}
else:
    print("TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN not set; "
          "WebRTC ICE servers are not configured.")
    client = None
    token = None
    rtc_configuration = None

# getUserMedia video constraints: fixed 320x240 capture at a low frame rate.
# "frameRate" is the video-track constraint; "sampleRate" applies only to
# audio tracks and would be ignored here.
webcam_constraints = {
    "video": {
        "width": {"exact": 320},
        "height": {"exact": 240},
        "frameRate": {"ideal": 2, "max": 5},
    }
}
class AsyncFrameProcessor:
    """
    Run pose estimation on a live stream without blocking the stream callback.

    ``process`` (the Gradio stream callback) stores each incoming frame in a
    single-slot "latest input" buffer, wakes a background worker thread, and
    immediately returns the most recent processed frame. The worker always
    takes only the newest frame, so frames arriving while another is being
    processed are dropped rather than queued.

    NOTE(review): one instance has exactly one input slot and one output slot.
    If several Gradio sessions share an instance, their frames overwrite each
    other and a session may be shown another session's processed frame.
    """

    def __init__(self, processing_delay=0.5, startup_delay=0.0):
        """
        Create the processor and start its background worker thread.

        Args:
            processing_delay (float): Stored on the instance; not used by
                this class.
            startup_delay (float): Stored on the instance; not used by
                this class.
        """
        self.processing_delay = processing_delay
        self.startup_delay = startup_delay
        self.first_call_time = None
        self.frame_counter = 0

        # Single-slot buffers, each guarded by its own lock.
        self.input_lock = threading.Lock()
        self.output_lock = threading.Lock()
        self.latest_input_frame = None
        self.latest_output_frame = None

        # Threading components.
        self.processing_thread = None
        self.stop_event = threading.Event()
        self.new_frame_signal = threading.Event()

        self._start_processing_thread()

    def _start_processing_thread(self):
        """Start the background worker thread unless one is already running."""
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_event.clear()
            self.processing_thread = threading.Thread(
                target=self._processing_worker, daemon=True
            )
            self.processing_thread.start()

    def _processing_worker(self):
        """Worker loop: wait for a new-frame signal, then process the latest frame."""
        while not self.stop_event.is_set():
            # Time out periodically so stop_event is re-checked while idle.
            if not self.new_frame_signal.wait(timeout=1.0):
                continue
            self.new_frame_signal.clear()
            if self.stop_event.is_set():
                break
            print("New frame received, starting processing...")
            try:
                self._process_latest()
            except Exception as exc:
                # A single failing frame must not end this thread; otherwise
                # the output would stay frozen on the last good frame forever.
                print(f"Frame processing failed: {exc!r}")

    def _process_latest(self):
        """Process the newest input frame (if any) and publish the result."""
        with self.input_lock:
            if self.latest_input_frame is None:
                return None
            frame_to_process = self.latest_input_frame.copy()
            frame_number = self.frame_counter

        print(f"Processing frame number: {frame_number}")
        processed_frame = process_frame(frame_to_process)
        # The result may be a non-contiguous (e.g. mirrored) view; cv2 drawing
        # functions need a contiguous buffer they can write into.
        processed_frame = np.ascontiguousarray(processed_frame)
        # Write the frame number in the top-left corner.
        processed_frame = cv2.putText(
            processed_frame,
            "{:d}".format(frame_number),
            (50, 50),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1,
            color=(0, 0, 255),
            thickness=2,
        )
        print(f"Frame {frame_number} processed.")

        with self.output_lock:
            self.latest_output_frame = processed_frame
        return None

    def process(self, frame):
        """
        Gradio stream callback: queue ``frame`` and return the latest result.

        The incoming frame replaces any frame still waiting in the input slot.

        Returns:
            The most recent processed frame; until one exists, a copy of
            ``frame`` labelled "Waiting... <n>". Returns ``None`` when
            ``frame`` is ``None`` (nothing to store or annotate).
        """
        if frame is None:
            return None

        if self.first_call_time is None:
            self.first_call_time = time.time()

        # Replace whatever frame is waiting in the input slot.
        with self.input_lock:
            self.latest_input_frame = frame.copy()
            self.frame_counter += 1
            waiting_count = self.frame_counter

        self.new_frame_signal.set()

        with self.output_lock:
            latest_output = self.latest_output_frame
        if latest_output is not None:
            return latest_output

        # Nothing processed yet: return the raw frame with a marker.
        temp_frame = frame.copy()
        cv2.putText(
            temp_frame,
            f"Waiting... {waiting_count}",
            (50, 50),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 0, 0),  # Red for unprocessed frames
            2,
        )
        return temp_frame

    def stop(self):
        """Signal the worker to exit and wait up to 2 s for it to finish."""
        self.stop_event.set()
        # Wake the worker now instead of waiting for its 1 s wait() timeout.
        self.new_frame_signal.set()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=2.0)
# CSS for components tagged with elem_classes "my-group" / "my-column":
# a group capped at 600x600 px inside a flex column that centers its content.
# (Lengths need a unit: a bare "600" is invalid CSS and silently dropped.)
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# A single processor (with its one background worker thread) for the whole app.
# NOTE(review): every Gradio session streaming through this instance shares its
# one input/output slot — confirm this is intended for concurrent users.
frame_processor = AsyncFrameProcessor(startup_delay=0.0, processing_delay=0.5)
# Gradio UI: page header, then a webcam input image next to a streaming output image.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        value="""
<h1 style='text-align: center'>
ProbPose Webcam Demo (CVPR 2025)
</h1>
"""
    )
    gr.HTML(
        value="""
<h3 style='text-align: center'>
See <a href="https://MiraPurkrabek.github.io/ProbPose/" target="_blank">https://MiraPurkrabek.github.io/ProbPose/</a> for details.
</h3>
"""
    )

    # Alternative WebRTC-based streaming UI, currently disabled:
    # with gr.Column(elem_classes=["my-column"]):
    #     with gr.Group(elem_classes=["my-group"]):
    #         webcam_stream = WebRTC(
    #             label="Webcam Stream",
    #             rtc_configuration=rtc_configuration,
    #             track_constraints=webcam_constraints,
    #             mirror_webcam=True,
    #         )
    #         # Stream processing: connects webcam input to frame processor
    #         webcam_stream.stream(
    #             fn=frame_processor.process,
    #             inputs=[webcam_stream],
    #             outputs=[webcam_stream],
    #             time_limit=120,  # Limit processing time to 120 seconds
    #         )

    with gr.Row():
        with gr.Column():
            # Webcam capture, delivered to the callback as a numpy array.
            input_img = gr.Image(sources=["webcam"], type="numpy")
        with gr.Column():
            output_img = gr.Image(streaming=True)

    # Each webcam frame is handed to the shared processor, and whatever frame
    # it returns is displayed; each stream is limited to 30 seconds.
    # NOTE(review): `frame_processor` is a single shared instance, so
    # concurrent sessions (up to concurrency_limit=30) share one input/output
    # slot and can see each other's frames — confirm.
    dep = input_img.stream(
        fn=frame_processor.process,
        inputs=[input_img],
        outputs=[output_img],
        time_limit=30,
        concurrency_limit=30,
    )
if __name__ == "__main__":
    # Launch with Gradio's defaults. For local or LAN testing, options such as
    # server_name="0.0.0.0", server_port=17860 or share=True can be passed here.
    demo.launch()