Spaces: Running on Zero

Commit · 5b5416f
Parent(s): ba25bef

update for zero gpu

Files changed:
- README.md (+1 / -1)
- app.py (+67 / -116)
- requirements.txt (+1 / -1)
README.md
@@ -4,7 +4,7 @@ emoji: 👀
 colorFrom: purple
 colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.47.2
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py
@@ -1,15 +1,17 @@
 import colorsys
 import gc
+from copy import deepcopy
 from typing import Optional

 import gradio as gr
 import numpy as np
+import spaces
 import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw

 # Prefer local transformers in the workspace
-from transformers import Sam2VideoModel, Sam2VideoProcessor
+from transformers import AutoModel, Sam2VideoProcessor


 def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
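The new `spaces` import is what ties the app to Hugging Face ZeroGPU: the process starts without a GPU, and one is attached only while a function decorated with `spaces.GPU` is running, which is why the rest of the diff moves model loading and the default device to CPU. A minimal, hedged sketch of that pattern (the tiny model and the `run_on_gpu` helper are illustrative, not part of the app):

```python
# Minimal ZeroGPU sketch (illustrative names, not the app's code).
# Outside the decorated call there is no GPU; inside it, CUDA is available.
import spaces
import torch

model = torch.nn.Linear(4, 4)  # loaded on CPU at startup

@spaces.GPU(duration=60)  # request a GPU for up to ~60 s per call
def run_on_gpu(x: list[float]) -> list[float]:
    m = model.to("cuda")  # move weights only for the duration of the call
    with torch.inference_mode():
        y = m(torch.tensor(x, device="cuda"))
    return y.cpu().tolist()
```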
@@ -52,10 +54,12 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
                cap.release()
                if fps_val and fps_val > 0:
                    info["fps"] = float(fps_val)
-           except Exception:
+           except Exception as e:
+               print(f"Failed to render video with cv2: {e}")
                pass
        return pil_frames, info
-   except Exception:
+   except Exception as e:
+       print(f"Failed to load video with transformers.video_utils: {e}")
        # Fallback to OpenCV
        try:
            import cv2  # type: ignore
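The loader keeps OpenCV as the fallback when `transformers.video_utils` cannot decode the file; the hunk above only adds logging so the reason for the fallback is visible. For reference, a minimal sketch of such an OpenCV fallback reader (a hypothetical helper, not the app's exact implementation):

```python
# Hedged sketch of an OpenCV frame-loading fallback (hypothetical helper).
import cv2  # type: ignore
from PIL import Image

def load_frames_with_cv2(path: str) -> tuple[list[Image.Image], dict]:
    cap = cv2.VideoCapture(path)
    frames: list[Image.Image] = []
    info: dict = {}
    fps_val = cap.get(cv2.CAP_PROP_FPS)
    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        # OpenCV decodes to BGR; PIL expects RGB
        frames.append(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
    cap.release()
    if fps_val and fps_val > 0:
        info["fps"] = float(fps_val)
    info["num_frames"] = len(frames)
    return frames, info
```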
@@ -115,7 +119,7 @@ def overlay_masks_on_frame(


 def get_device_and_dtype() -> tuple[str, torch.dtype]:
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cpu"
     dtype = torch.bfloat16
     return device, dtype

@@ -127,9 +131,9 @@ class AppState:
     def reset(self):
         self.video_frames: list[Image.Image] = []
         self.inference_session = None
-        self.model: Optional[Sam2VideoModel] = None
+        self.model: Optional[AutoModel] = None
         self.processor: Optional[Sam2VideoProcessor] = None
-        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device: str = "cpu"
         self.dtype: torch.dtype = torch.bfloat16
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
@@ -153,6 +157,9 @@ class AppState:
         self.model_repo_id: str | None = None
         self.session_repo_id: str | None = None

+    def __repr__(self):
+        return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
+
     @property
     def num_frames(self) -> int:
         return len(self.video_frames)
@@ -168,29 +175,18 @@ def _model_repo_from_key(key: str) -> str:
     return mapping.get(key, mapping["base_plus"])


-def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
+def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
         if GLOBAL_STATE.model_repo_id == desired_repo:
             return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
     # Different repo requested: dispose current and reload
-    try:
-        del GLOBAL_STATE.model
-    except Exception:
-        pass
-    try:
-        del GLOBAL_STATE.processor
-    except Exception:
-        pass
     GLOBAL_STATE.model = None
     GLOBAL_STATE.processor = None
     print(f"Loading model from {desired_repo}")
     device, dtype = get_device_and_dtype()
     # free up the gpu memory
-
-    gc.collect()
-    print("device", device)
-    model = Sam2VideoModel.from_pretrained(desired_repo)
+    model = AutoModel.from_pretrained(desired_repo)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
     model.to(device, dtype=dtype)

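Loading now goes through `AutoModel.from_pretrained` instead of the class-specific `Sam2VideoModel`, so any checkpoint the pinned transformers fork can resolve (including, presumably, the EdgeTAM variants the requirements change points at) works without touching this function. A hedged sketch of the loading step; the repo id below is an assumed example, the app actually resolves it via `_model_repo_from_key`:

```python
# Hedged sketch of the loading pattern above; the checkpoint id is an
# assumed example, and the fork pinned in requirements.txt must be installed.
import torch
from transformers import AutoModel, Sam2VideoProcessor

repo_id = "facebook/sam2.1-hiera-base-plus"  # assumption, not the app's mapping
model = AutoModel.from_pretrained(repo_id)
processor = Sam2VideoProcessor.from_pretrained(repo_id)
model.to("cpu", dtype=torch.bfloat16)  # stays on CPU until a ZeroGPU call needs it
```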
@@ -216,24 +212,11 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
     GLOBAL_STATE.clicks_by_frame_obj.clear()
     GLOBAL_STATE.boxes_by_frame_obj.clear()
     GLOBAL_STATE.composited_frames.clear()
-    # Dispose previous session cleanly
-    try:
-        if GLOBAL_STATE.inference_session is not None:
-            GLOBAL_STATE.inference_session.reset_inference_session()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = None
-    gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = processor.init_video_session(
-        video=GLOBAL_STATE.video_frames,
         inference_device=device,
         video_storage_device="cpu",
-
+        dtype=dtype,
     )
     GLOBAL_STATE.session_repo_id = desired_repo

@@ -267,43 +250,21 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
     # Enforce max duration of 8 seconds (trim if longer)
     MAX_SECONDS = 8.0
     trimmed_note = ""
-    fps_in =
-
-
-
-
-
-
-        max_frames_allowed = int(MAX_SECONDS * fps_in)
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
-        if isinstance(info, dict):
-            info["num_frames"] = len(frames)
-    else:
-        # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
-        max_frames_allowed = 240
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
-        if isinstance(info, dict):
-            info["num_frames"] = len(frames)
-
+    fps_in = info.get("fps")
+    max_frames_allowed = int(MAX_SECONDS * fps_in)
+    if len(frames) > max_frames_allowed:
+        frames = frames[:max_frames_allowed]
+        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
+    if isinstance(info, dict):
+        info["num_frames"] = len(frames)
     GLOBAL_STATE.video_frames = frames
     # Try to capture original FPS if provided by loader
-    GLOBAL_STATE.video_fps = None
-    if isinstance(info, dict) and info.get("fps"):
-        try:
-            GLOBAL_STATE.video_fps = float(info["fps"]) or None
-        except Exception:
-            GLOBAL_STATE.video_fps = None
-
+    GLOBAL_STATE.video_fps = float(fps_in)
     # Initialize session
     inference_session = processor.init_video_session(
-        video=frames,
         inference_device=device,
         video_storage_device="cpu",
-
+        dtype=dtype,
     )
     GLOBAL_STATE.inference_session = inference_session

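The trimming logic collapses to a single arithmetic step now that the FPS comes straight from the loader's `info` dict; the old fallback branch that capped unknown-FPS videos at 240 frames is removed. As a worked example of the new frame budget (values are illustrative):

```python
# Worked example of the 8-second frame budget above (values are examples).
MAX_SECONDS = 8.0
fps_in = 30.0                                   # FPS as reported by the video loader
max_frames_allowed = int(MAX_SECONDS * fps_in)  # 8 s * 30 fps = 240 frames
frames = list(range(300))                       # pretend the clip has 300 frames
if len(frames) > max_frames_allowed:
    frames = frames[:max_frames_allowed]
print(len(frames))  # -> 240
```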
@@ -414,6 +375,12 @@ def on_image_click(
     processor = state.processor
     model = state.model
     inference_session = state.inference_session
+    original_size = None
+    pixel_values = None
+    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
+        original_size = inputs.original_sizes[0]
+        pixel_values = inputs.pixel_values[0]

     if state.current_prompt_type == "Boxes":
         # Two-click box input
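Because the session is now created without the video, frames are preprocessed on demand: a frame the session has not cached yet goes through the processor, and the resulting `pixel_values` (plus `original_size` for prompts) accompany the forward call. A condensed, hedged sketch of that flow (the wrapper function is illustrative; model, processor, and session come from the app state):

```python
# Hedged sketch of the lazy per-frame preprocessing used above.
import torch
from PIL import Image

def forward_single_frame(model, processor, inference_session, frame: Image.Image, frame_idx: int, device: str):
    pixel_values = None
    # Only preprocess frames the session has not already cached.
    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
        inputs = processor(images=frame, device=device, return_tensors="pt")
        pixel_values = inputs.pixel_values[0]
    with torch.inference_mode():
        return model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
```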
@@ -445,6 +412,7 @@
             obj_ids=int(obj_id),
             input_boxes=[[[x_min, y_min, x_max, y_max]]],
             clear_old_inputs=True,  # For boxes, always clear old inputs
+            original_size=original_size,
         )

         frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
@@ -467,6 +435,7 @@
             obj_ids=int(obj_id),
             input_points=[[[[int(x), int(y)]]]],
             input_labels=[[[int(label_int)]]],
+            original_size=original_size,
             clear_old_inputs=bool(clear_old),
         )

@@ -478,12 +447,8 @@
     state.composited_frames.pop(int(frame_idx), None)

     # Forward on that frame
-
-
-    outputs = model(
-        inference_session=inference_session,
-        frame_idx=int(frame_idx),
-    )
+    with torch.inference_mode():
+        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))

     H = inference_session.video_height
     W = inference_session.video_width
@@ -509,31 +474,37 @@
     return update_frame_display(state, int(frame_idx))


+@spaces.GPU()
 def propagate_masks(GLOBAL_STATE: gr.State):
     if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
-        yield "Load a video first.", gr.update()
-        return
+        # yield GLOBAL_STATE, "Load a video first.", gr.update()
+        return GLOBAL_STATE, "Load a video first.", gr.update()

-    processor = GLOBAL_STATE.processor
-    model = GLOBAL_STATE.model
-    inference_session = GLOBAL_STATE.inference_session
+    processor = deepcopy(GLOBAL_STATE.processor)
+    model = deepcopy(GLOBAL_STATE.model)
+    inference_session = deepcopy(GLOBAL_STATE.inference_session)
+    # set inference device to cuda to use zero gpu
+    inference_session.inference_device = "cuda"
+    inference_session.cache.inference_device = "cuda"
+    model.to("cuda")

     total = max(1, GLOBAL_STATE.num_frames)
     processed = 0

     # Initial status; no slider change yet
-    yield f"Propagating masks: {processed}/{total}", gr.update()
+    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()

-    device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
     last_frame_idx = 0
-    with torch.inference_mode()
-    for
+    with torch.inference_mode():
+        for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
+            pixel_values = None
+            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
+            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
             H = inference_session.video_height
             W = inference_session.video_width
             pred_masks = sam2_video_output.pred_masks.detach().cpu()
             video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
-
-            frame_idx = int(sam2_video_output.frame_idx)
             last_frame_idx = frame_idx
             masks_for_frame: dict[int, np.ndarray] = {}
             obj_ids_order = list(inference_session.obj_ids)
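This hunk is the heart of the ZeroGPU change: the whole propagation loop runs inside `@spaces.GPU()`, and rather than mutating the CPU-resident objects stored in `gr.State`, the handler deep-copies the model, processor, and session and retargets the copies to `"cuda"` for the duration of the call. A condensed, hedged sketch of that shape (mask post-processing and bookkeeping omitted; attribute names follow the diff):

```python
# Condensed, hedged sketch of the ZeroGPU propagation pattern above.
from copy import deepcopy

import spaces
import torch

@spaces.GPU()
def propagate(state):
    # Work on copies so the CPU-side objects kept in gr.State stay untouched.
    model = deepcopy(state.model).to("cuda")
    processor = deepcopy(state.processor)
    session = deepcopy(state.inference_session)
    session.inference_device = "cuda"        # session tensors now live on the GPU
    session.cache.inference_device = "cuda"

    with torch.inference_mode():
        for frame_idx, frame in enumerate(state.video_frames):
            pixel_values = None
            if session.processed_frames is None or frame_idx not in session.processed_frames:
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            output = model(inference_session=session, frame=pixel_values, frame_idx=frame_idx)
            yield state, f"Propagating masks: {frame_idx + 1}/{len(state.video_frames)}", output
```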
@@ -546,16 +517,13 @@ def propagate_masks(GLOBAL_STATE: gr.State):

             processed += 1
             # Every 15th frame (or last), move slider to current frame to update preview via slider binding
-            if processed % 15 == 0 or processed == total:
-                yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
-
-
+            if processed % 30 == 0 or processed == total:
+                yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+
+    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."

     # Final status; ensure slider points to last processed frame
-    yield (
-        f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
-        gr.update(value=last_frame_idx),
-    )
+    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)


 def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
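Each `yield` now carries three values, the state object, a status string, and a slider update, which is why the click wiring later in the diff widens to `outputs=[GLOBAL_STATE, propagate_status, frame_slider]`. A minimal, hedged sketch of that Gradio generator pattern (component names are illustrative, not the app's):

```python
# Minimal sketch of a Gradio generator event updating several outputs per yield.
import gradio as gr

def long_job(state):
    total = 5
    for i in range(total):
        # (state, status text, slider update) -- one value per declared output
        yield state, f"Processing {i + 1}/{total}", gr.update(value=i)
    yield state, "Done", gr.update(value=total - 1)

with gr.Blocks() as demo:
    state = gr.State({})
    status = gr.Markdown()
    slider = gr.Slider(0, 10, step=1)
    run = gr.Button("Run")
    run.click(long_job, inputs=[state], outputs=[state, status, slider])
```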
@@ -581,11 +549,6 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
         pass
     GLOBAL_STATE.inference_session = None
     gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     ensure_session_for_current_model(GLOBAL_STATE)

     # Keep current slider index if possible
@@ -786,29 +749,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
         out_path = "/tmp/sam2_playback.mp4"
         # Prefer imageio with PyAV/ffmpeg to respect exact fps
         try:
-            import
+            import cv2  # type: ignore

-
+            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+            writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+            for fr_bgr in frames_np:
+                writer.write(fr_bgr)
+            writer.release()
             return out_path
-        except Exception:
-
-
-            import imageio.v2 as imageio  # type: ignore
-
-            imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
-            return out_path
-        except Exception:
-            try:
-                import cv2  # type: ignore
-
-                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
-                writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
-                for fr_bgr in frames_np:
-                    writer.write(fr_bgr)
-                writer.release()
-                return out_path
-            except Exception as e:
-                raise gr.Error(f"Failed to render video: {e}")
+        except Exception as e:
+            print(f"Failed to render video with cv2: {e}")
+            raise gr.Error(f"Failed to render video: {e}")

     render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])

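The renderer now writes the MP4 with OpenCV directly instead of trying imageio/PyAV first. One detail worth noting: `cv2.VideoWriter` expects BGR, uint8 frames of exactly `(h, w, 3)`, which is why the removed imageio path reversed the channel order (`fr[:, :, ::-1]`) while the new path writes `frames_np` as-is. A standalone, hedged sketch of the writer (frame contents, size, and fps are example values):

```python
# Hedged sketch of writing an MP4 with OpenCV, mirroring the new render path.
import cv2  # type: ignore
import numpy as np

fps, w, h = 30, 640, 360
frames_bgr = [np.zeros((h, w, 3), dtype=np.uint8) for _ in range(fps)]  # 1 s of black frames

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter("/tmp/example.mp4", fourcc, fps, (w, h))
for frame in frames_bgr:
    writer.write(frame)  # VideoWriter expects BGR, uint8, shape (h, w, 3)
writer.release()
```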
@@ -816,7 +767,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
     propagate_btn.click(
         propagate_masks,
         inputs=[GLOBAL_STATE],
-        outputs=[propagate_status, frame_slider],
+        outputs=[GLOBAL_STATE, propagate_status, frame_slider],
     )

     reset_btn.click(
requirements.txt
@@ -1,5 +1,5 @@
 gradio
-git+https://github.com/
+git+https://github.com/yonigozlan/transformers.git@add-edgetam
 torch
 torchvision
 pillow
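Pinning transformers to the `add-edgetam` branch of the `yonigozlan` fork is what makes `Sam2VideoProcessor` (and, presumably, the EdgeTAM checkpoints) importable in app.py. pip resolves this requirements line the same way as the equivalent command `pip install "git+https://github.com/yonigozlan/transformers.git@add-edgetam"`, so the pin can be reproduced locally for testing.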