Miroslav Purkrabek commited on
Commit
4241c11
·
1 Parent(s): ebb18ad

init models first

Browse files
Files changed (1) hide show
  1. app.py +29 -53
app.py CHANGED
@@ -50,38 +50,34 @@ POSE_WEIGHTS = "models/ProbPose-s.pth"
50
  DEVICE = 'cuda:0'
51
  # DEVICE = 'cpu'
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @spaces.GPU
54
- def init_models():
55
- """Initialize and return models for pose estimation"""
56
- # Init detector
57
- print("Initializing MMDetection detector...")
58
- det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE)
59
- det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
60
- print("Detector initialized successfully!")
61
-
62
- # Init pose estimator
63
- print("Initializing MMPose estimator...")
64
- pose_model = init_pose_estimator(
65
- POSE_CFG,
66
- POSE_WEIGHTS,
67
- device=DEVICE,
68
- cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True)))
69
- )
70
-
71
- # Build visualizer
72
- pose_model.cfg.visualizer.radius = 4
73
- pose_model.cfg.visualizer.alpha = 0.8
74
- pose_model.cfg.visualizer.line_width = 2
75
- visualizer = VISUALIZERS.build(pose_model.cfg.visualizer)
76
- visualizer.set_dataset_meta(
77
- pose_model.dataset_meta, skeleton_style='mmpose'
78
- )
79
- print("Pose estimator initialized successfully!")
80
- return det_model, pose_model, visualizer
81
-
82
- @spaces.GPU
83
- def process_frame(frame, det_model, pose_model, visualizer, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3):
84
  """Process a single frame with pose estimation"""
 
 
85
  processing_start = time.time()
86
 
87
  # Mirror the frame
@@ -173,14 +169,6 @@ class AsyncFrameProcessor:
173
  self.stop_event = threading.Event()
174
  self.new_frame_signal = threading.Event()
175
 
176
- # Detector and pose estimator models
177
- self.pose_model = None
178
- self.det_model = None
179
- self.visualizer = None
180
-
181
- # Initialize models using the standalone function
182
- self.det_model, self.pose_model, self.visualizer = init_models()
183
-
184
  # Start background processing
185
  self._start_processing_thread()
186
 
@@ -213,20 +201,8 @@ class AsyncFrameProcessor:
213
 
214
  print(f"Processing frame number: {frame_number}")
215
 
216
- # Initialize models if not already done
217
- if self.det_model is None or self.pose_model is None or self.visualizer is None:
218
- self.det_model, self.pose_model, self.visualizer = init_models()
219
-
220
- print("Models initialized, starting frame processing...")
221
-
222
-
223
- # Process the frame using the standalone function
224
- processed_frame = process_frame(
225
- frame_to_process,
226
- self.det_model,
227
- self.pose_model,
228
- self.visualizer
229
- )
230
 
231
  # Write frame number in the top left corner
232
  processed_frame = cv2.putText(
@@ -341,7 +317,7 @@ with gr.Blocks(css=css) as demo:
341
 
342
 
343
  if __name__ == "__main__":
344
-
345
  demo.launch(
346
  # server_name="0.0.0.0",
347
  # server_port=17860,
 
50
  DEVICE = 'cuda:0'
51
  # DEVICE = 'cpu'
52
 
53
+ # Global variables for models
54
+ print("Initializing MMDetection detector...")
55
+ det_model = init_detector(DET_CFG, DET_WEIGHTS, device=DEVICE)
56
+ det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)
57
+ print("Detector initialized successfully!")
58
+
59
+ print("Initializing MMPose estimator...")
60
+ pose_model = init_pose_estimator(
61
+ POSE_CFG,
62
+ POSE_WEIGHTS,
63
+ device=DEVICE,
64
+ cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=True)))
65
+ )
66
+ print("Pose estimator initialized successfully!")
67
+
68
+ pose_model.cfg.visualizer.radius = 4
69
+ pose_model.cfg.visualizer.alpha = 0.8
70
+ pose_model.cfg.visualizer.line_width = 2
71
+ visualizer = VISUALIZERS.build(pose_model.cfg.visualizer)
72
+ visualizer.set_dataset_meta(
73
+ pose_model.dataset_meta, skeleton_style='mmpose'
74
+ )
75
+
76
  @spaces.GPU
77
+ def process_frame(frame, bbox_thr=0.3, nms_thr=0.8, kpt_thr=0.3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  """Process a single frame with pose estimation"""
79
+ global det_model, pose_model, visualizer
80
+
81
  processing_start = time.time()
82
 
83
  # Mirror the frame
 
169
  self.stop_event = threading.Event()
170
  self.new_frame_signal = threading.Event()
171
 
 
 
 
 
 
 
 
 
172
  # Start background processing
173
  self._start_processing_thread()
174
 
 
201
 
202
  print(f"Processing frame number: {frame_number}")
203
 
204
+ # Process the frame using the global models
205
+ processed_frame = process_frame(frame_to_process)
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  # Write frame number in the top left corner
208
  processed_frame = cv2.putText(
 
317
 
318
 
319
  if __name__ == "__main__":
320
+
321
  demo.launch(
322
  # server_name="0.0.0.0",
323
  # server_port=17860,