chongzhou committed
Commit 917decd · 1 Parent(s): 282a45a

set device in inference_state manually

Files changed (2):
  1. .gitignore +1 -0
  2. app.py +51 -43
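
Context for the change: the prompt handlers (preprocess_video_in, segment_with_points) pin the predictor to CPU, while propagate_to_all moves it to CUDA, and the device cached in inference_state must follow the weights or later tensor allocations presumably end up on the wrong device. A minimal sketch of the pattern this commit applies; the stub class and helper below are illustrative stand-ins, not names from app.py:

import torch


class StubPredictor(torch.nn.Module):
    # Illustrative stand-in for the SAM 2-style video predictor in app.py.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @property
    def device(self):
        return next(self.parameters()).device


def sync_device(predictor, inference_state, device):
    # Move the weights, then record the new device in the inference state,
    # so later tensor allocations follow the model instead of a stale value.
    predictor.to(device)
    inference_state["device"] = predictor.device


predictor = StubPredictor()
inference_state = {"device": predictor.device}
sync_device(predictor, inference_state, "cpu")  # prompt handlers pin to CPU
assert inference_state["device"] == torch.device("cpu")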
.gitignore CHANGED
@@ -1,2 +1,3 @@
 *.egg-info/
 __pycache__/
+*.DS_Store
app.py CHANGED
@@ -174,6 +174,8 @@ def preprocess_video_in(
     input_labels,
     inference_state,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -255,6 +257,8 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
@@ -336,55 +340,59 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-
-    predictor.to("cuda")
-
-    if len(input_points) == 0 or video_in is None or inference_state is None:
-        return None
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = {}  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        inference_state
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-    # obtain the segmentation results every few frames
-    vis_frame_stride = 1
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
-            "RGBA"
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        predictor.to("cuda")
+        inference_state["device"] = predictor.device
+
+        if len(input_points) == 0 or video_in is None or inference_state is None:
+            return None
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = (
+            {}
+        )  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video")
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            inference_state
+        ):
+            video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # obtain the segmentation results every few frames
+        vis_frame_stride = 1
+
+        output_frames = []
+        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+                "RGBA"
+            )
+            out_mask = video_segments[out_frame_idx][OBJ_ID]
+            mask_image = show_mask(out_mask)
+            output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = np.array(output_frame)
+            output_frames.append(output_frame)
+
+        torch.cuda.empty_cache()
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps  # Frames per second
+        clip = ImageSequenceClip(output_frames, fps=fps)
+        # Write the result to a file
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(
+            tempfile.gettempdir(), final_vid_output_path
        )
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    torch.cuda.empty_cache()
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps  # Frames per second
-    clip = ImageSequenceClip(output_frames, fps=fps)
-    # Write the result to a file
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
 
-    # Write the result to a file
-    clip.write_videofile(final_vid_output_path, codec="libx264")
+        # Write the result to a file
+        clip.write_videofile(final_vid_output_path, codec="libx264")
 
-    return gr.update(value=final_vid_output_path)
+        return gr.update(value=final_vid_output_path)
 
 
 def update_ui():
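
The hunk above also replaces the old commented-out torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() with a real with block. Calling __enter__() without a matching __exit__() would leave autocast enabled for everything that runs on the thread afterwards; a with block scopes it to the propagation code and is unwound even on an early return. A minimal sketch of the scoped form (requires a CUDA device; heavy_gpu_work is a made-up placeholder):

import torch


def heavy_gpu_work(device):
    # Placeholder for the propagation and compositing in propagate_to_all.
    a = torch.randn(8, 8, device=device)
    return (a @ a).float()


def run_propagation(device="cuda"):
    # Scoped autocast: active only inside the block and unwound on exit,
    # including on early returns, unlike a dangling __enter__() call.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        return heavy_gpu_work(device)


if torch.cuda.is_available():
    print(run_propagation("cuda").dtype)  # torch.float32 after .float()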
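
For reference, the frames-to-video step at the end of propagate_to_all uses moviepy's ImageSequenceClip. A self-contained sketch of that write path with synthetic frames; the frame shape and fps here are arbitrary, and the import path is the moviepy 1.x one:

import os
import tempfile
from datetime import datetime

import numpy as np
from moviepy.editor import ImageSequenceClip  # moviepy 1.x import path

# Synthetic RGB frames standing in for the composited mask overlays.
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(24)]

clip = ImageSequenceClip(frames, fps=24)  # app.py uses the source video's fps
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
out_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")
clip.write_videofile(out_path, codec="libx264")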