set device in inference_state manually
- .gitignore +1 -0
- app.py +51 -43
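This commit keeps the SAM2 predictor on the CPU while the user is placing click prompts (preprocess_video_in, segment_with_points) and only moves it to CUDA for mask propagation (propagate_to_all). SAM2's video predictor records a device in the inference state and allocates its per-frame tensors there, so the state's "device" entry has to be updated by hand after every predictor.to(...) call, or propagation can hit device-mismatch errors. A minimal sketch of the pattern, assuming a SAM2-style predictor that exposes a .device property (the move_predictor helper is hypothetical, not part of app.py):

    import torch

    def move_predictor(predictor, inference_state, device):
        """Move the model and keep the tracking state's device in sync."""
        predictor.to(device)
        # app.py uses the predictor's .device property; for a plain nn.Module
        # the equivalent is next(predictor.parameters()).device
        inference_state["device"] = predictor.device

    # interactive prompting stays on CPU to free GPU memory between clicks:
    #   move_predictor(predictor, inference_state, "cpu")
    # propagation runs on the GPU:
    #   move_predictor(predictor, inference_state, "cuda")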
.gitignore
CHANGED
@@ -1,2 +1,3 @@
 *.egg-info/
 __pycache__/
+*.DS_Store
app.py
CHANGED
@@ -174,6 +174,8 @@ def preprocess_video_in(
     input_labels,
     inference_state,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -255,6 +257,8 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
@@ -336,55 +340,59 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-
-    if len(input_points) == 0 or video_in is None or inference_state is None:
-        return None
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = (
-        {}
-    )  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        inference_state
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-    # obtain the segmentation results every few frames
-    vis_frame_stride = 1
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
-            "RGBA"
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        predictor.to("cuda")
+        inference_state["device"] = predictor.device
+
+        if len(input_points) == 0 or video_in is None or inference_state is None:
+            return None
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = (
+            {}
+        )  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video")
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            inference_state
+        ):
+            video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # obtain the segmentation results every few frames
+        vis_frame_stride = 1
+
+        output_frames = []
+        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+                "RGBA"
+            )
+            out_mask = video_segments[out_frame_idx][OBJ_ID]
+            mask_image = show_mask(out_mask)
+            output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = np.array(output_frame)
+            output_frames.append(output_frame)
+
+        torch.cuda.empty_cache()
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps  # Frames per second
+        clip = ImageSequenceClip(output_frames, fps=fps)
+        # Write the result to a file
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(
+            tempfile.gettempdir(), final_vid_output_path
         )
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    torch.cuda.empty_cache()
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps  # Frames per second
-    clip = ImageSequenceClip(output_frames, fps=fps)
-    # Write the result to a file
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
 
-    # Write the result to a file
-    clip.write_videofile(final_vid_output_path, codec="libx264")
+        # Write the result to a file
+        clip.write_videofile(final_vid_output_path, codec="libx264")
 
-    return gr.update(value=final_vid_output_path)
+        return gr.update(value=final_vid_output_path)
 
 
 def update_ui():