Spaces:
Sleeping
Sleeping
File size: 11,447 Bytes
4f5aed8 3cd9b36 a8ae8d5 4f5aed8 73e5f29 f2550ff 73e5f29 3cd9b36 73e5f29 50e6e4b 73e5f29 ac2dad3 3cd9b36 ebb18ad 3cd9b36 ebb18ad 3cd9b36 73e5f29 472855f f41366b 4241c11 472855f 4241c11 f41366b 4241c11 f41366b 73e5f29 ac2dad3 73e5f29 b9d8842 73e5f29 b9d8842 4241c11 73e5f29 b9d8842 73e5f29 4b0b7fc 73e5f29 50e6e4b 73e5f29 50e6e4b 73e5f29 bf754bc 18ee5d7 bf754bc 93ec95f bf754bc 18ee5d7 73e5f29 4241c11 73e5f29 30aaf29 73e5f29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
# Run the setup.sh install script before running this app.
import os
# os.system("bash setup.sh")
import gradio as gr
import cv2
from gradio_webrtc import WebRTC
import time
import threading
import numpy as np
from time import sleep
import spaces
from twilio.rest import Client
try:
import mmcv
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline
except (ImportError, ModuleNotFoundError):
os.system("pip uninstall -y mmpose mmdet mmcv mmengine mmpretrain")
os.system("pip install -U openmim")
os.system("mim install mmengine mmcv==2.1.0 mmdet==3.3.0 mmpretrain==1.2.0")
import mmcv
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.utils import adapt_mmdet_pipeline
import hashlib
try:
from mmdet.apis import inference_detector, init_detector
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False
# Model configs / checkpoints. The relative paths resolve against the current
# working directory (presumably the repository root -- TODO confirm).
DET_CFG = "demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py"
DET_WEIGHTS = "https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
POSE_CFG = "configs/body_2d_keypoint/topdown_probmap/coco/td-pm_ProbPose-small_8xb64-210e_coco-256x192.py"
POSE_WEIGHTS = "models/ProbPose-s.pth"
DEVICE = 'cuda:0'
# DEVICE = 'cpu'

# Global models, built once at import time and shared by every frame.
if not has_mmdet:
    # The optional mmdet import failed, so init_detector / inference_detector
    # are undefined. Stop with an actionable message instead of a NameError
    # on the next statement.
    raise ImportError(
        "mmdet is required for person detection but could not be imported. "
        "Install it (e.g. `mim install mmdet==3.3.0`) and restart the app."
    )

print("Initializing MMDetection detector...")
det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE)
# Rewrite the detector's pipeline config with mmpose's adapt_mmdet_pipeline helper.
det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
print("Detector initialized successfully!")

print("Initializing MMPose estimator...")
pose_model = init_pose_estimator(
    POSE_CFG,
    POSE_WEIGHTS,
    device=DEVICE,
    # Override the model's test_cfg so it also outputs heatmaps.
    cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True)))
)
print("Pose estimator initialized successfully!")

# Drawing style for the single shared visualizer.
pose_model.cfg.visualizer.radius = 4
pose_model.cfg.visualizer.alpha = 0.8
pose_model.cfg.visualizer.line_width = 2
visualizer = VISUALIZERS.build(pose_model.cfg.visualizer)
visualizer.set_dataset_meta(
    pose_model.dataset_meta, skeleton_style='mmpose'
)
@spaces.GPU
def process_frame(frame, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3):
    """Detect the most confident person in a frame and draw its pose.

    The frame is first mirrored horizontally (webcam "selfie" view). Person
    boxes (detector label 0) scoring above ``bbox_thr`` are collected, and
    only the single highest-scoring one is passed to the pose estimator.

    NOTE: this draws on the module-level ``visualizer``, whose image state is
    mutated by every call, so calls should not run concurrently.

    Args:
        frame: H x W x C image array from the webcam input (assumed RGB
            uint8 -- TODO confirm against the Gradio component).
        bbox_thr: Minimum detector score for a person box to be kept.
        nms_thr: Unused; accepted for signature compatibility. NMS is
            unnecessary because only the top-scoring box is kept.
        kpt_thr: Keypoint score threshold forwarded to the visualizer.

    Returns:
        The visualizer's rendered image when a person is found, otherwise the
        mirrored frame as a C-contiguous array.
    """
    # Mirror the frame. ``frame[:, ::-1]`` alone is a negative-stride view,
    # which OpenCV drawing calls on the returned image (cv2.putText) reject,
    # so materialize a contiguous copy.
    frame = np.ascontiguousarray(frame[:, ::-1, :])

    det_result = inference_detector(det_model, frame)
    pred_instance = det_result.pred_instances.cpu().numpy()
    # Each row: 4 box coordinates followed by the detection score (column 4).
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    keep = np.logical_and(pred_instance.labels == 0,
                          pred_instance.scores > bbox_thr)
    bboxes = bboxes[keep]

    if len(bboxes) == 0:
        # No person detected: return the mirrored frame unchanged.
        return frame

    # Keep only the most confident box, as a (1, 4) coordinate array.
    best = int(np.argmax(bboxes[:, 4]))
    bboxes = bboxes[best:best + 1, :4]

    visualizer.set_image(frame)
    pose_results = inference_topdown(pose_model, frame, bboxes)
    data_samples = merge_data_samples(pose_results)

    visualizer.add_datasample(
        'result',
        frame,
        data_sample=data_samples,
        draw_gt=False,
        draw_heatmap=False,
        draw_bbox=True,
        show_kpt_idx=False,
        show=False,
        kpt_thr=kpt_thr)
    return visualizer.get_image()
# WebRTC configuration for webcam streaming.
# Twilio provides the ICE servers. Without credentials, fall back to
# rtc_configuration = None instead of crashing at import time, since the
# webcam demo can run without WebRTC.
_twilio_sid = os.getenv("TWILIO_ACCOUNT_SID")
_twilio_auth_token = os.getenv("TWILIO_AUTH_TOKEN")
if _twilio_sid and _twilio_auth_token:
    client = Client(_twilio_sid, _twilio_auth_token)
    token = client.tokens.create()  # ICE servers are exposed as token.ice_servers
    rtc_configuration = {"iceServers": token.ice_servers}
else:
    print("TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN not set; "
          "WebRTC will run without ICE servers.")
    client = None
    token = None
    rtc_configuration = None

webcam_constraints = {
    "video": {
        "width": {"exact": 320},
        "height": {"exact": 240},
        # Frames per second. "sampleRate" (used previously) is an
        # audio-only constraint and is ignored for video tracks.
        "frameRate": {"ideal": 2, "max": 5}
    }
}
class AsyncFrameProcessor:
    """
    Run pose estimation on a background thread, decoupled from the stream.

    Incoming frames go into a single-slot input buffer, where a newer frame
    replaces an older unprocessed one. The most recent result sits in a
    single-slot output buffer. The stream callback therefore never waits on
    the model, and no backlog of stale frames can build up.
    """

    def __init__(self, processing_delay=0.5, startup_delay=0.0):
        """
        Initialize the processor and start its background worker thread.

        Args:
            processing_delay (float): Stored on the instance but not used by
                the processing loop; kept for backward compatibility.
            startup_delay (float): Stored on the instance but not used by the
                processing loop; kept for backward compatibility.
        """
        self.processing_delay = processing_delay
        self.startup_delay = startup_delay
        self.first_call_time = None
        self.frame_counter = 0
        # Single-slot buffers, each guarded by its own lock.
        self.input_lock = threading.Lock()
        self.output_lock = threading.Lock()
        self.latest_input_frame = None
        self.latest_output_frame = None
        # Worker-thread control.
        self.processing_thread = None
        self.stop_event = threading.Event()
        self.new_frame_signal = threading.Event()
        # Start background processing
        self._start_processing_thread()

    def _start_processing_thread(self):
        """Start the daemon worker thread unless one is already running."""
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_event.clear()
            self.processing_thread = threading.Thread(
                target=self._processing_worker, daemon=True)
            self.processing_thread.start()

    def _processing_worker(self):
        """Worker loop: process the latest input frame whenever one arrives.

        An exception while processing a frame is printed and the loop keeps
        running. Otherwise one bad frame would end the thread, and the stream
        would silently freeze on the last result.
        """
        import traceback

        while not self.stop_event.is_set():
            # Time out every second so stop_event is re-checked.
            if not self.new_frame_signal.wait(timeout=1.0):
                continue
            self.new_frame_signal.clear()
            print("New frame received, starting processing...")

            # Snapshot the latest input frame and its number.
            with self.input_lock:
                if self.latest_input_frame is None:
                    continue
                frame_to_process = self.latest_input_frame.copy()
                frame_number = self.frame_counter

            print(f"Processing frame number: {frame_number}")
            try:
                processed_frame = process_frame(frame_to_process)
                # Write frame number in the top left corner
                processed_frame = cv2.putText(
                    processed_frame,
                    "{:d}".format(frame_number),
                    [50, 50],
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=1,
                    color=(0, 0, 255),
                    thickness=2,
                )
            except Exception:
                # Keep the worker alive; the previous output stays on screen.
                print(f"Frame {frame_number} failed:")
                traceback.print_exc()
                continue
            print(f"Frame {frame_number} processed.")

            # Publish the result for the stream callback.
            with self.output_lock:
                self.latest_output_frame = processed_frame

    def process(self, frame):
        """
        Gradio stream callback: enqueue ``frame`` and return the newest result.

        Args:
            frame: Latest webcam frame (numpy array), or None.

        Returns:
            The most recent processed frame, if one exists. Otherwise, a copy
            of ``frame`` with a "Waiting..." counter drawn on it. If ``frame``
            is None, nothing is enqueued and the most recent processed frame
            (possibly None) is returned.
        """
        if self.first_call_time is None:
            self.first_call_time = time.time()

        if frame is None:
            with self.output_lock:
                return self.latest_output_frame

        # Replace any frame still waiting; only the newest one matters.
        with self.input_lock:
            self.latest_input_frame = frame.copy()
            self.frame_counter += 1
            frame_count = self.frame_counter
        # Wake the worker.
        self.new_frame_signal.set()

        with self.output_lock:
            latest_output = self.latest_output_frame
        if latest_output is not None:
            return latest_output

        # Nothing processed yet: show the raw frame with a placeholder label.
        temp_frame = frame.copy()
        cv2.putText(
            temp_frame,
            f"Waiting... {frame_count}",
            (50, 50),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 0, 0),  # Red for unprocessed frames
            2,
        )
        return temp_frame

    def stop(self):
        """Signal the worker to exit and wait up to 2 seconds for it."""
        self.stop_event.set()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=2.0)
# CSS for the .my-group / .my-column element classes of the Gradio interface.
# Non-zero CSS lengths need a unit; a bare "600" is invalid and browsers drop
# the whole declaration.
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# One processor (with one background worker thread) serves the whole app.
# NOTE(review): every browser session streams into this same object, and the
# stream event below allows up to 30 concurrent runs. Simultaneous users
# therefore share the single-slot buffers and may be shown each other's
# processed frames.
frame_processor = AsyncFrameProcessor(processing_delay=0.5)

# Gradio UI: page header, then the webcam input beside the streamed output.
with gr.Blocks(css=css) as demo:
    gr.HTML(value="""
<h1 style='text-align: center'>
ProbPose Webcam Demo (CVPR 2025)
</h1>
""")
    gr.HTML(value="""
<h3 style='text-align: center'>
See <a href="https://MiraPurkrabek.github.io/ProbPose/" target="_blank">https://MiraPurkrabek.github.io/ProbPose/</a> for details.
</h3>
""")

    # Disabled alternative: WebRTC streaming through rtc_configuration.
    # with gr.Column(elem_classes=["my-column"]):
    #     with gr.Group(elem_classes=["my-group"]):
    #         webcam_stream = WebRTC(
    #             label="Webcam Stream",
    #             rtc_configuration=rtc_configuration,
    #             track_constraints=webcam_constraints,
    #             mirror_webcam=True,
    #         )
    #     # Stream processing: connects webcam input to frame processor
    #     webcam_stream.stream(
    #         fn=frame_processor.process,
    #         inputs=[webcam_stream],
    #         outputs=[webcam_stream],
    #         time_limit=120,  # Limit processing time to 120 seconds
    #     )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(sources=["webcam"], type="numpy")
        with gr.Column():
            output_img = gr.Image(streaming=True)

    # Every webcam frame goes to the processor; its latest result is displayed.
    dep = input_img.stream(
        fn=frame_processor.process,
        inputs=[input_img],
        outputs=[output_img],
        time_limit=30,
        concurrency_limit=30,
    )
if __name__ == "__main__":
    # Launch with Gradio's defaults. For a self-hosted run, options such as
    # server_name="0.0.0.0", server_port=17860 or share=True may be passed.
    demo.launch()