hysts (HF Staff) committed
Commit 6228bae · 1 Parent(s): 2f7dd3f
Files changed (3)
  1. app.py +35 -38
  2. pyproject.toml +60 -0
  3. uv.lock +0 -0
app.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
 import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw, ImageFont
-
 from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
 
 
@@ -153,7 +152,7 @@ def init_video_session(
     device = _GLOBAL_DEVICE
     dtype = _GLOBAL_DTYPE
 
-    video_path: Optional[str] = None
+    video_path: str | None = None
     if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
     elif isinstance(video, str):
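
An aside on the type-hint change above: `str | None` is PEP 604 union syntax, available from Python 3.10 onward, which matches the `requires-python = ">=3.10"` floor in the new pyproject.toml below, and it removes the need to import `typing.Optional`:

    # PEP 604 (Python >= 3.10): unions are spelled with "|"; no typing import needed
    video_path: str | None = None  # same meaning as Optional[str]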
@@ -286,7 +285,7 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
         try:
             font = ImageFont.truetype(font_path, font_size)
             break
-        except (OSError, IOError):
+        except OSError:
             continue
     if font is None:
         # Fallback to default font
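
Dropping `IOError` from the except clause is safe: since Python 3.3 it has been a plain alias of `OSError`, so the old two-element tuple caught the same exception type twice:

    # IOError has been an alias of OSError since Python 3.3
    assert IOError is OSError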
@@ -338,7 +337,7 @@ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
     return compose_frame(state, frame_idx)
 
 
-def _get_prompt_for_obj(state: AppState, obj_id: int) -> Optional[str]:
+def _get_prompt_for_obj(state: AppState, obj_id: int) -> str | None:
     """Get the prompt text associated with an object ID."""
     # Priority 1: Check text_prompts_by_frame_obj (most reliable)
     for frame_texts in state.text_prompts_by_frame_obj.values():
@@ -415,29 +414,28 @@ def on_image_click(
             state.pending_box_start_obj_id = ann_obj_id
             state.composited_frames.pop(ann_frame_idx, None)
             return update_frame_display(state, ann_frame_idx)
-        else:
-            x1, y1 = state.pending_box_start
-            x2, y2 = int(x), int(y)
-            state.pending_box_start = None
-            state.pending_box_start_frame_idx = None
-            state.pending_box_start_obj_id = None
-            state.composited_frames.pop(ann_frame_idx, None)
-            x_min, y_min = min(x1, x2), min(y1, y2)
-            x_max, y_max = max(x1, x2), max(y1, y2)
-
-            box = [[[x_min, y_min, x_max, y_max]]]
-            processor.add_inputs_to_inference_session(
-                inference_session=state.inference_session,
-                frame_idx=ann_frame_idx,
-                obj_ids=ann_obj_id,
-                input_boxes=box,
-            )
-
-            frame_boxes = state.boxes_by_frame_obj.setdefault(ann_frame_idx, {})
-            obj_boxes = frame_boxes.setdefault(ann_obj_id, [])
-            obj_boxes.clear()
-            obj_boxes.append((x_min, y_min, x_max, y_max))
-            state.composited_frames.pop(ann_frame_idx, None)
+        x1, y1 = state.pending_box_start
+        x2, y2 = int(x), int(y)
+        state.pending_box_start = None
+        state.pending_box_start_frame_idx = None
+        state.pending_box_start_obj_id = None
+        state.composited_frames.pop(ann_frame_idx, None)
+        x_min, y_min = min(x1, x2), min(y1, y2)
+        x_max, y_max = max(x1, x2), max(y1, y2)
+
+        box = [[[x_min, y_min, x_max, y_max]]]
+        processor.add_inputs_to_inference_session(
+            inference_session=state.inference_session,
+            frame_idx=ann_frame_idx,
+            obj_ids=ann_obj_id,
+            input_boxes=box,
+        )
+
+        frame_boxes = state.boxes_by_frame_obj.setdefault(ann_frame_idx, {})
+        obj_boxes = frame_boxes.setdefault(ann_obj_id, [])
+        obj_boxes.clear()
+        obj_boxes.append((x_min, y_min, x_max, y_max))
+        state.composited_frames.pop(ann_frame_idx, None)
     else:
         label_int = 1 if str(label).lower().startswith("pos") else 0
 
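The large hunk above changes indentation only: the `else:` after an early `return` is redundant, so the box-completion branch moves up one level with identical behavior. A self-contained sketch of the guard-clause pattern (hypothetical names, not app.py's API):

    # Hypothetical two-click box input, mirroring the structure above
    def on_click(state: dict, x: int, y: int) -> tuple[int, int, int, int] | None:
        if state.get("pending") is None:
            state["pending"] = (x, y)  # first corner: stash it and return early
            return None
        x1, y1 = state.pop("pending")  # second corner: no "else" needed after the return
        return (min(x1, x), min(y1, y), max(x1, x), max(y1, y))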
@@ -654,7 +652,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         return
 
     # Add all prompts to the inference session (processor handles deduplication)
-    for text_prompt in text_prompt_to_obj_ids.keys():
+    for text_prompt in text_prompt_to_obj_ids:
         GLOBAL_STATE.inference_session = processor.add_text_prompt(
             inference_session=GLOBAL_STATE.inference_session,
             text=text_prompt,
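
Iterating a dict directly yields its keys, so the `.keys()` call was redundant; Ruff's SIM118 (in-dict-keys) flags exactly this pattern, and is plausibly what prompted the change given the `select = ["ALL"]` lint config added below:

    d = {"a person": [1], "a dog": [2]}
    assert list(d) == list(d.keys())  # same keys, same order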
@@ -840,17 +838,16 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
         GLOBAL_STATE.inference_session, "reset_inference_session"
     ):
         GLOBAL_STATE.inference_session.reset_inference_session()
-    else:
-        if GLOBAL_STATE.video_frames:
-            processor = _GLOBAL_TRACKER_PROCESSOR
-            raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
-            GLOBAL_STATE.inference_session = processor.init_video_session(
-                video=raw_video,
-                inference_device=_GLOBAL_DEVICE,
-                video_storage_device="cpu",
-                processing_device="cpu",
-                dtype=_GLOBAL_DTYPE,
-            )
+    elif GLOBAL_STATE.video_frames:
+        processor = _GLOBAL_TRACKER_PROCESSOR
+        raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
+        GLOBAL_STATE.inference_session = processor.init_video_session(
+            video=raw_video,
+            inference_device=_GLOBAL_DEVICE,
+            video_storage_device="cpu",
+            processing_device="cpu",
+            dtype=_GLOBAL_DTYPE,
+        )
 
     GLOBAL_STATE.masks_by_frame.clear()
     GLOBAL_STATE.clicks_by_frame_obj.clear()
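
Similarly, `elif` is exactly equivalent to the removed `else:` with a nested `if`, just one indentation level flatter (Ruff's PLR5501, collapsible-else-if, targets this pattern). A minimal illustration with hypothetical names:

    def label_to_int(label: str) -> int:
        if label == "positive":
            value = 1
        elif label == "negative":  # previously spelled: else: if label == "negative": ...
            value = 0
        else:
            value = -1
        return value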
 
pyproject.toml ADDED
@@ -0,0 +1,60 @@
+[project]
+name = "sam3-video-segmentation"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "accelerate>=1.11.0",
+    "gradio>=5.49.1",
+    "imageio[pyav]>=2.37.2",
+    "kernels>=0.11.0",
+    "opencv-python>=4.12.0.88",
+    "spaces>=0.42.1",
+    "torch==2.8.0",
+    "torchvision>=0.23.0",
+    "transformers",
+]
+
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = [
+    "COM812",  # missing-trailing-comma
+    "D203",    # one-blank-line-before-class
+    "D213",    # multi-line-summary-second-line
+    "E501",    # line-too-long
+    "SIM117",  # multiple-with-statements
+    #
+    "D100",    # undocumented-public-module
+    "D101",    # undocumented-public-class
+    "D102",    # undocumented-public-method
+    "D103",    # undocumented-public-function
+    "D104",    # undocumented-public-package
+    "D105",    # undocumented-magic-method
+    "D107",    # undocumented-public-init
+    "EM101",   # raw-string-in-exception
+    "FBT001",  # boolean-type-hint-positional-argument
+    "FBT002",  # boolean-default-value-positional-argument
+    "PGH003",  # blanket-type-ignore
+    "PLR0913", # too-many-arguments
+    "PLR0915", # too-many-statements
+    "TRY003",  # raise-vanilla-args
+]
+unfixable = [
+    "F401",  # unused-import
+]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.per-file-ignores]
+"*.ipynb" = ["T201", "T203"]
+
+[tool.ruff.format]
+docstring-code-format = true
+
+[tool.uv.sources]
+transformers = { git = "https://github.com/huggingface/transformers.git", rev = "69f003696b" }
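
One non-cosmetic detail in the new file: `[tool.uv.sources]` pins transformers to a specific git revision rather than a PyPI release, presumably because the `Sam3*` video classes that app.py imports are not yet in a released version. After `uv sync`, a quick sanity check (assuming the pinned rev ships these classes, as app.py's own import line implies):

    # Should import cleanly when the git-pinned transformers revision is installed
    from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor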
uv.lock ADDED
The diff for this file is too large to render. See raw diff