################## Bootstrap: install dependencies before anything else is imported
import subprocess

# Path to the install.sh script; adjust if your checkout uses a different layout
install_script_path = "script/install.sh"
# Run the install script; note that subprocess.run does not raise on a non-zero
# exit, so inspect the returned CompletedProcess.returncode if failures should abort startup
subprocess.run(["bash", install_script_path])
##################
# --- Standard library ---
import argparse
import gc
import importlib
import json
import math
import os
import pdb
import sys
import time

# --- Third-party ---
import cv2
import gradio as gr
import numpy as np
import torch
from matplotlib.pyplot import step
from PIL import Image
from PIL.ImageOps import colorize, scale
from skimage.morphology.binary import binary_dilation

# --- Local modules ---
# sys.path.append('.')
# sys.path.append('..')
from model_args import segtracker_args, sam_args, aot_args
from SegTracker import SegTracker
from seg_track_anything import aot_model2ckpt, tracking_objects_in_video, draw_mask
from tool.transfer_tools import draw_outline, draw_points, mask2bbox
def clean():
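    """Clear all input widgets, the tracker, and the click stack when switching input tabs."""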
return None, None, None, None, None, None, [[], []]
def get_click_prompt(click_stack, point):
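    """Record a new click on the stack and build the point-prompt dict for SAM."""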
    click_stack[0].append(point["coord"])
    click_stack[1].append(point["mode"])

    prompt = {
        "points_coord": click_stack[0],
        "points_mode": click_stack[1],
        "multimask": "True",
    }
return prompt
def get_meta_from_video(input_video):
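    """Read the first frame of the uploaded video and return it as an RGB array."""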
if input_video is None:
return None, None, None, ""
print("get meta information of input video")
cap = cv2.VideoCapture(input_video)
_, first_frame = cap.read()
cap.release()
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
return first_frame, first_frame, first_frame, ""
def get_meta_from_img_seq(input_img_seq):
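    """Unzip the uploaded image sequence under ./assets and return its first frame as RGB."""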
if input_img_seq is None:
return None, None, None, ""
print("get meta information of img seq")
# Create dir
    file_name = input_img_seq.name.split('/')[-1].split('.')[0]
    file_path = f'./assets/{file_name}'
    if os.path.isdir(file_path):
        os.system(f'rm -r "{file_path}"')
    os.makedirs(file_path)
    # Unzip file (quote the path so file names with spaces survive the shell)
    os.system(f'unzip "{input_img_seq.name}" -d ./assets')
imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
first_frame = imgs_path[0]
first_frame = cv2.imread(first_frame)
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
return first_frame, first_frame, first_frame, ""
def SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask):
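    """Restart the tracker and register predicted_mask on the first frame as its reference."""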
with torch.cuda.amp.autocast():
# Reset the first frame's mask
frame_idx = 0
Seg_Tracker.restart_tracker()
Seg_Tracker.add_reference(origin_frame, predicted_mask, frame_idx)
Seg_Tracker.first_frame_mask = predicted_mask
return Seg_Tracker
def init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
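    """Create a fresh SegTracker from the current UI arguments and reset the click stack."""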
if origin_frame is None:
return None, origin_frame, [[], []], ""
# reset aot args
aot_args["model"] = aot_model
aot_args["model_path"] = aot_model2ckpt[aot_model]
aot_args["long_term_mem_gap"] = long_term_mem
aot_args["max_len_long_term"] = max_len_long_term
# reset sam args
segtracker_args["sam_gap"] = sam_gap
segtracker_args["max_obj_num"] = max_obj_num
sam_args["generator_args"]["points_per_side"] = points_per_side
Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
Seg_Tracker.restart_tracker()
return Seg_Tracker, origin_frame, [[], []], ""
def init_SegTracker_Stroke(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
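    """Variant of init_SegTracker that also resets the stroke drawing board."""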
if origin_frame is None:
return None, origin_frame, [[], []], origin_frame
# reset aot args
aot_args["model"] = aot_model
aot_args["model_path"] = aot_model2ckpt[aot_model]
aot_args["long_term_mem_gap"] = long_term_mem
aot_args["max_len_long_term"] = max_len_long_term
# reset sam args
segtracker_args["sam_gap"] = sam_gap
segtracker_args["max_obj_num"] = max_obj_num
sam_args["generator_args"]["points_per_side"] = points_per_side
Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
Seg_Tracker.restart_tracker()
return Seg_Tracker, origin_frame, [[], []], origin_frame
def undo_click_stack_and_refine_seg(Seg_Tracker, origin_frame, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
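    """Remove the most recent click and re-run segmentation with the remaining clicks."""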
if Seg_Tracker is None:
return Seg_Tracker, origin_frame, [[], []]
print("Undo!")
if len(click_stack[0]) > 0:
click_stack[0] = click_stack[0][: -1]
click_stack[1] = click_stack[1][: -1]
if len(click_stack[0]) > 0:
            prompt = {
                "points_coord": click_stack[0],
                "points_mode": click_stack[1],
                "multimask": "True",
            }
masked_frame = seg_acc_click(Seg_Tracker, prompt, origin_frame)
return Seg_Tracker, masked_frame, click_stack
else:
return Seg_Tracker, origin_frame, [[], []]
def roll_back_undo_click_stack_and_refine_seg(Seg_Tracker, origin_frame, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side,input_video, input_img_seq, frame_num, refine_idx):
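    """Undo the latest refinement click on a rolled-back frame and re-splice the object mask."""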
if Seg_Tracker is None:
return Seg_Tracker, origin_frame, [[], []]
print("Undo!")
if len(click_stack[0]) > 0:
click_stack[0] = click_stack[0][: -1]
click_stack[1] = click_stack[1][: -1]
if len(click_stack[0]) > 0:
            prompt = {
                "points_coord": click_stack[0],
                "points_mode": click_stack[1],
                "multimask": "True",
            }
chosen_frame_show, curr_mask, ori_frame = res_by_num(input_video, input_img_seq, frame_num)
Seg_Tracker.curr_idx = refine_idx
predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
origin_frame=origin_frame,
coords=np.array(prompt["points_coord"]),
modes=np.array(prompt["points_mode"]),
multimask=prompt["multimask"],
)
            # Replace the refined object's old region with the newly predicted mask
            curr_mask[curr_mask == refine_idx] = 0
            curr_mask[predicted_mask != 0] = refine_idx
            predicted_mask = curr_mask
Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
return Seg_Tracker, masked_frame, click_stack
else:
return Seg_Tracker, origin_frame, [[], []]
def seg_acc_click(Seg_Tracker, prompt, origin_frame):
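    """Segment the first frame according to the accumulated click prompt and register the result."""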
# seg acc to click
predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
origin_frame=origin_frame,
coords=np.array(prompt["points_coord"]),
modes=np.array(prompt["points_mode"]),
multimask=prompt["multimask"],
)
Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
return masked_frame
def sam_click(Seg_Tracker, origin_frame, point_mode, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, evt:gr.SelectData):
"""
Args:
origin_frame: nd.array
click_stack: [[coordinate], [point_mode]]
"""
print("Click")
if point_mode == "Positive":
point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
    else:
        # TODO: add everything positive points
        point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
if Seg_Tracker is None:
Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
# get click prompts for sam to predict mask
click_prompt = get_click_prompt(click_stack, point)
# Refine acc to prompt
masked_frame = seg_acc_click(Seg_Tracker, click_prompt, origin_frame)
return Seg_Tracker, masked_frame, click_stack
def roll_back_sam_click(Seg_Tracker, origin_frame, point_mode, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, input_video, input_img_seq, frame_num, refine_idx, evt:gr.SelectData):
"""
Args:
origin_frame: nd.array
click_stack: [[coordinate], [point_mode]]
"""
print("Click")
if point_mode == "Positive":
point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
    else:
        # TODO: add everything positive points
        point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
if Seg_Tracker is None:
Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
# get click prompts for sam to predict mask
prompt = get_click_prompt(click_stack, point)
chosen_frame_show, curr_mask, ori_frame = res_by_num(input_video, input_img_seq, frame_num)
Seg_Tracker.curr_idx = refine_idx
predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
origin_frame=origin_frame,
coords=np.array(prompt["points_coord"]),
modes=np.array(prompt["points_mode"]),
multimask=prompt["multimask"],
)
    # Replace the refined object's old region with the newly predicted mask
    curr_mask[curr_mask == refine_idx] = 0
    curr_mask[predicted_mask != 0] = refine_idx
    predicted_mask = curr_mask
Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
return Seg_Tracker, masked_frame, click_stack
def sam_stroke(Seg_Tracker, origin_frame, drawing_board, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
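    """Turn the sketch on the drawing board into a bounding box and segment inside it."""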
if Seg_Tracker is None:
Seg_Tracker, _ , _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
print("Stroke")
mask = drawing_board["mask"]
bbox = mask2bbox(mask[:, :, 0]) # bbox: [[x0, y0], [x1, y1]]
predicted_mask, masked_frame = Seg_Tracker.seg_acc_bbox(origin_frame, bbox)
Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
return Seg_Tracker, masked_frame, origin_frame
def gd_detect(Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
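    """Detect objects matching grounding_caption and segment them on the first frame."""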
if Seg_Tracker is None:
Seg_Tracker, _ , _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
print("Detect")
predicted_mask, annotated_frame= Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
masked_frame = draw_mask(annotated_frame, predicted_mask)
return Seg_Tracker, masked_frame, origin_frame
def segment_everything(Seg_Tracker, aot_model, long_term_mem, max_len_long_term, origin_frame, sam_gap, max_obj_num, points_per_side):
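    """Run SAM's segment-everything mode on the first frame and register the result."""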
if Seg_Tracker is None:
Seg_Tracker, _ , _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
print("Everything")
frame_idx = 0
with torch.cuda.amp.autocast():
pred_mask = Seg_Tracker.seg(origin_frame)
torch.cuda.empty_cache()
gc.collect()
Seg_Tracker.add_reference(origin_frame, pred_mask, frame_idx)
Seg_Tracker.first_frame_mask = pred_mask
masked_frame = draw_mask(origin_frame.copy(), pred_mask)
return Seg_Tracker, masked_frame
def add_new_object(Seg_Tracker):
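    """Freeze the current merged mask and bump the object index so new clicks start a new object."""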
prev_mask = Seg_Tracker.first_frame_mask
Seg_Tracker.update_origin_merged_mask(prev_mask)
Seg_Tracker.curr_idx += 1
print("Ready to add new object!")
return Seg_Tracker, [[], []]
def tracking_objects(Seg_Tracker, input_video, input_img_seq, fps, frame_num=0):
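    """Run tracking over the whole video / image sequence; returns the output video and masks."""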
print("Start tracking !")
# pdb.set_trace()
# output_video, output_mask=tracking_objects_in_video(Seg_Tracker, input_video, input_img_seq, fps)
# pdb.set_trace()
return tracking_objects_in_video(Seg_Tracker, input_video, input_img_seq, fps, frame_num)
def res_by_num(input_video, input_img_seq, frame_num):
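    """Fetch the tracking result for frame_num.

    Returns (masked_frame_rgb, palette_mask, original_frame_rgb), or
    (None, None, None) when no results exist or frame_num is out of range.
    """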
if input_video is not None:
video_name = os.path.basename(input_video).split('.')[0]
cap = cv2.VideoCapture(input_video)
        # Read frames sequentially until frame_num is reached
        for i in range(0, frame_num + 1):
            _, ori_frame = cap.read()
cap.release()
ori_frame = cv2.cvtColor(ori_frame, cv2.COLOR_BGR2RGB)
elif input_img_seq is not None:
file_name = input_img_seq.name.split('/')[-1].split('.')[0]
file_path = f'./assets/{file_name}'
video_name = file_name
imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
ori_frame = imgs_path[frame_num]
ori_frame = cv2.imread(ori_frame)
ori_frame = cv2.cvtColor(ori_frame, cv2.COLOR_BGR2RGB)
else:
return None, None, None
tracking_result_dir = f'{os.path.join(os.path.dirname(__file__), "tracking_results", f"{video_name}")}'
output_masked_frame_dir = f'{tracking_result_dir}/{video_name}_masked_frames'
output_masked_frame_path = sorted([os.path.join(output_masked_frame_dir, img_name) for img_name in os.listdir(output_masked_frame_dir)])
output_mask_dir = f'{tracking_result_dir}/{video_name}_masks'
output_mask_path = sorted([os.path.join(output_mask_dir, img_name) for img_name in os.listdir(output_mask_dir)])
if len(output_masked_frame_path) == 0:
return None, None, None
else:
if frame_num >= len(output_masked_frame_path):
print("num out of frames range")
return None, None, None
else:
print("choose", frame_num, "to refine")
chosen_frame_show = output_masked_frame_path[frame_num]
chosen_frame_show = cv2.imread(chosen_frame_show)
chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
            # Load the mask in palette mode so per-object indices are preserved
            chosen_mask = Image.open(output_mask_path[frame_num]).convert('P')
            chosen_mask = np.array(chosen_mask)
return chosen_frame_show, chosen_mask, ori_frame
def show_res_by_slider(input_video, input_img_seq, frame_per):
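    """Map a 0-100 slider percentage to a frame index and display that frame's result."""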
if input_video is not None:
video_name = os.path.basename(input_video).split('.')[0]
elif input_img_seq is not None:
file_name = input_img_seq.name.split('/')[-1].split('.')[0]
file_path = f'./assets/{file_name}'
video_name = file_name
else:
print("Not find output res")
return None, None
tracking_result_dir = f'{os.path.join(os.path.dirname(__file__), "tracking_results", f"{video_name}")}'
output_masked_frame_dir = f'{tracking_result_dir}/{video_name}_masked_frames'
output_masked_frame_path = sorted([os.path.join(output_masked_frame_dir, img_name) for img_name in os.listdir(output_masked_frame_dir)])
total_frames_num = len(output_masked_frame_path)
if total_frames_num == 0:
print("Not find output res")
return None, None
else:
frame_num = math.floor(total_frames_num * frame_per / 100)
if frame_per == 100:
            frame_num = frame_num - 1
chosen_frame_show, _, _ = res_by_num(input_video, input_img_seq, frame_num)
return chosen_frame_show, frame_num
def choose_obj_to_refine(input_video, input_img_seq, Seg_Tracker, frame_num, evt:gr.SelectData):
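    """Identify the object id under the clicked pixel and outline it on the result frame."""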
chosen_frame_show, curr_mask, _ = res_by_num(input_video, input_img_seq, frame_num)
# curr_mask=Seg_Tracker.first_frame_mask
if curr_mask is not None and chosen_frame_show is not None:
idx = curr_mask[evt.index[1],evt.index[0]]
curr_idx_mask = np.where(curr_mask == idx, 1, 0).astype(np.uint8)
chosen_frame_show = draw_points(points=np.array([[evt.index[0],evt.index[1]]]), modes=np.array([[1]]), frame=chosen_frame_show)
chosen_frame_show = draw_outline(mask=curr_idx_mask, frame=chosen_frame_show)
print(idx)
return chosen_frame_show, idx
def show_chosen_idx_to_refine(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, input_video, input_img_seq, Seg_Tracker, frame_num, idx):
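    """Reset tracker state so the chosen frame/object can be re-segmented from scratch."""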
chosen_frame_show, curr_mask, ori_frame = res_by_num(input_video, input_img_seq, frame_num)
if Seg_Tracker is None:
print("reset aot args, new SegTracker")
Seg_Tracker, _ , _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, ori_frame)
Seg_Tracker.restart_tracker()
Seg_Tracker.curr_idx = 1
Seg_Tracker.object_idx = 1
Seg_Tracker.origin_merged_mask = None
Seg_Tracker.first_frame_mask = None
Seg_Tracker.reference_objs_list=[]
Seg_Tracker.everything_points = []
Seg_Tracker.everything_labels = []
Seg_Tracker.sam.have_embedded = False
Seg_Tracker.sam.interactive_predictor.features = None
return ori_frame, Seg_Tracker, ori_frame, [[], []], ""
def seg_track_app():
##########################################################
###################### Front-end ########################
##########################################################
app = gr.Blocks()
with app:
gr.Markdown(
'''
<div style="text-align:center;">
<span style="font-size:3em; font-weight:bold;">Segment and Track Anything(SAM-Track)</span>
</div>
'''
)
        click_stack = gr.State([[], []])  # Stores the click history: [[coordinates], [modes]]
origin_frame = gr.State(None)
Seg_Tracker = gr.State(None)
current_frame_num = gr.State(None)
refine_idx = gr.State(None)
frame_num = gr.State(None)
aot_model = gr.State(None)
sam_gap = gr.State(None)
points_per_side = gr.State(None)
max_obj_num = gr.State(None)
with gr.Row():
# video input
with gr.Column(scale=0.5):
tab_video_input = gr.Tab(label="Video type input")
with tab_video_input:
input_video = gr.Video(label='Input video').style(height=550)
tab_img_seq_input = gr.Tab(label="Image-Seq type input")
with tab_img_seq_input:
with gr.Row():
input_img_seq = gr.File(label='Input Image-Seq').style(height=550)
with gr.Column(scale=0.25):
extract_button = gr.Button(value="extract")
fps = gr.Slider(label='fps', minimum=5, maximum=50, value=8, step=1)
input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
tab_everything = gr.Tab(label="Everything")
with tab_everything:
with gr.Row():
seg_every_first_frame = gr.Button(value="Segment everything for first frame", interactive=True)
point_mode = gr.Radio(
choices=["Positive"],
value="Positive",
label="Point Prompt",
interactive=True)
every_undo_but = gr.Button(
value="Undo",
interactive=True
)
tab_click = gr.Tab(label="Click")
with tab_click:
with gr.Row():
point_mode = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
interactive=True)
# args for modify and tracking
click_undo_but = gr.Button(
value="Undo",
interactive=True
)
tab_stroke = gr.Tab(label="Stroke")
with tab_stroke:
drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
with gr.Row():
seg_acc_stroke = gr.Button(value="Segment", interactive=True)
tab_text = gr.Tab(label="Text")
with tab_text:
grounding_caption = gr.Textbox(label="Detection Prompt")
detect_button = gr.Button(value="Detect")
with gr.Accordion("Advanced options", open=False):
with gr.Row():
with gr.Column(scale=0.5):
box_threshold = gr.Slider(
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
with gr.Column(scale=0.5):
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Tab(label="SegTracker Args"):
                    # args for doing segment-everything during video tracking
points_per_side = gr.Slider(
label = "points_per_side",
minimum= 1,
step = 1,
maximum=100,
value=16,
interactive=True
)
sam_gap = gr.Slider(
label='sam_gap',
minimum = 1,
step=1,
maximum = 9999,
value=100,
interactive=True,
)
max_obj_num = gr.Slider(
label='max_obj_num',
minimum = 50,
step=1,
maximum = 300,
value=255,
interactive=True
)
with gr.Accordion("aot advanced options", open=False):
aot_model = gr.Dropdown(
label="aot_model",
choices = [
"deaotb",
"deaotl",
"r50_deaotl"
],
value = "r50_deaotl",
interactive=True,
)
long_term_mem = gr.Slider(label="long term memory gap", minimum=1, maximum=9999, value=9999, step=1)
max_len_long_term = gr.Slider(label="max len of long term memory", minimum=1, maximum=9999, value=9999, step=1)
with gr.Column():
new_object_button = gr.Button(
value="Add new object",
interactive=True
)
reset_button = gr.Button(
value="Reset",
interactive=True,
)
track_for_video = gr.Button(
value="Start Tracking",
interactive=True,
)
with gr.Column(scale=0.5):
# output_video = gr.Video(label='Output video').style(height=550)
output_video = gr.File(label="Predicted video")
output_mask = gr.File(label="Predicted masks")
with gr.Row():
with gr.Column(scale=1):
with gr.Accordion("roll back options", open=False):
output_res = gr.Image(label='Segment result of all frames').style(height=550)
frame_per = gr.Slider(
label = "Percentage of Frames Viewed",
minimum= 0.0,
maximum= 100.0,
step=0.01,
value=0.0,
)
frame_per.release(show_res_by_slider, inputs=[input_video, input_img_seq, frame_per], outputs=[output_res, frame_num])
roll_back_button = gr.Button(value="Choose this mask to refine")
                    refine_res = gr.Image(label='Refine masks').style(height=550)
tab_roll_back_click = gr.Tab(label="Click")
with tab_roll_back_click:
with gr.Row():
roll_back_point_mode = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
interactive=True)
# args for modify and tracking
roll_back_click_undo_but = gr.Button(
value="Undo",
interactive=True
)
roll_back_track_for_video = gr.Button(
value="Start tracking to refine",
interactive=True,
)
##########################################################
###################### back-end #########################
##########################################################
# listen to the input_video to get the first frame of video
input_video.change(
fn=get_meta_from_video,
inputs=[
input_video
],
outputs=[
input_first_frame, origin_frame, drawing_board, grounding_caption
]
)
# listen to the input_img_seq to get the first frame of video
input_img_seq.change(
fn=get_meta_from_img_seq,
inputs=[
input_img_seq
],
outputs=[
input_first_frame, origin_frame, drawing_board, grounding_caption
]
)
        # -------------- Input component --------------
tab_video_input.select(
fn = clean,
inputs=[],
outputs=[
input_video,
input_img_seq,
Seg_Tracker,
input_first_frame,
origin_frame,
drawing_board,
click_stack,
]
)
tab_img_seq_input.select(
fn = clean,
inputs=[],
outputs=[
input_video,
input_img_seq,
Seg_Tracker,
input_first_frame,
origin_frame,
drawing_board,
click_stack,
]
)
extract_button.click(
fn=get_meta_from_img_seq,
inputs=[
input_img_seq
],
outputs=[
input_first_frame, origin_frame, drawing_board, grounding_caption
]
)
# ------------------- Interactive component -----------------
# listen to the tab to init SegTracker
tab_everything.select(
fn=init_SegTracker,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
origin_frame
],
outputs=[
Seg_Tracker, input_first_frame, click_stack, grounding_caption
],
queue=False,
)
tab_click.select(
fn=init_SegTracker,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
origin_frame
],
outputs=[
Seg_Tracker, input_first_frame, click_stack, grounding_caption
],
queue=False,
)
tab_stroke.select(
fn=init_SegTracker_Stroke,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
origin_frame,
],
outputs=[
Seg_Tracker, input_first_frame, click_stack, drawing_board
],
queue=False,
)
tab_text.select(
fn=init_SegTracker,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
origin_frame
],
outputs=[
Seg_Tracker, input_first_frame, click_stack, grounding_caption
],
queue=False,
)
# Use SAM to segment everything for the first frame of video
seg_every_first_frame.click(
fn=segment_everything,
inputs=[
Seg_Tracker,
aot_model,
long_term_mem,
max_len_long_term,
origin_frame,
sam_gap,
max_obj_num,
points_per_side,
],
outputs=[
Seg_Tracker,
input_first_frame,
],
)
# Interactively modify the mask acc click
input_first_frame.select(
fn=sam_click,
inputs=[
Seg_Tracker, origin_frame, point_mode, click_stack,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
],
outputs=[
Seg_Tracker, input_first_frame, click_stack
]
)
# Interactively segment acc stroke
seg_acc_stroke.click(
fn=sam_stroke,
inputs=[
Seg_Tracker, origin_frame, drawing_board,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
],
outputs=[
Seg_Tracker, input_first_frame, drawing_board
]
)
# Use grounding-dino to detect object
detect_button.click(
fn=gd_detect,
inputs=[
Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold,
aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side
],
outputs=[
Seg_Tracker, input_first_frame
]
)
# Add new object
new_object_button.click(
fn=add_new_object,
inputs=
[
Seg_Tracker
],
outputs=
[
Seg_Tracker, click_stack
]
)
# Track object in video
track_for_video.click(
fn=tracking_objects,
inputs=[
Seg_Tracker,
input_video,
input_img_seq,
fps,
],
outputs=[
output_video, output_mask
]
)
# ----------------- Refine Mask ---------------------------
output_res.select(
fn = choose_obj_to_refine,
inputs=[
input_video, input_img_seq, Seg_Tracker, frame_num
],
outputs=[output_res, refine_idx]
)
roll_back_button.click(
fn=show_chosen_idx_to_refine,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
input_video, input_img_seq, Seg_Tracker, frame_num, refine_idx
],
outputs=[
refine_res, Seg_Tracker, origin_frame, click_stack, grounding_caption
],
queue=False,
show_progress=False
)
roll_back_click_undo_but.click(
fn = roll_back_undo_click_stack_and_refine_seg,
inputs=[
Seg_Tracker, origin_frame, click_stack,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
input_video, input_img_seq, frame_num, refine_idx
],
outputs=[
Seg_Tracker, refine_res, click_stack
]
)
refine_res.select(
fn=roll_back_sam_click,
inputs=[
Seg_Tracker, origin_frame, roll_back_point_mode, click_stack,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
input_video, input_img_seq, frame_num, refine_idx
],
outputs=[
Seg_Tracker, refine_res, click_stack
]
)
# Track object in video
roll_back_track_for_video.click(
fn=tracking_objects,
inputs=[
Seg_Tracker,
input_video,
input_img_seq,
fps, frame_num
],
outputs=[
output_video, output_mask
]
)
# ----------------- Reset and Undo ---------------------------
        # Reset
reset_button.click(
fn=init_SegTracker,
inputs=[
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
origin_frame
],
outputs=[
Seg_Tracker, input_first_frame, click_stack, grounding_caption
],
queue=False,
show_progress=False
)
# Undo click
click_undo_but.click(
fn = undo_click_stack_and_refine_seg,
inputs=[
Seg_Tracker, origin_frame, click_stack,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
],
outputs=[
Seg_Tracker, input_first_frame, click_stack
]
)
every_undo_but.click(
fn = undo_click_stack_and_refine_seg,
inputs=[
Seg_Tracker, origin_frame, click_stack,
aot_model,
long_term_mem,
max_len_long_term,
sam_gap,
max_obj_num,
points_per_side,
],
outputs=[
Seg_Tracker, input_first_frame, click_stack
]
)
with gr.Tab(label='Video example'):
gr.Examples(
examples=[
# os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.mp4"),
os.path.join(os.path.dirname(__file__), "assets", "blackswan.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "bear.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "camel.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "skate-park.mp4"),
# os.path.join(os.path.dirname(__file__), "assets", "swing.mp4"),
],
inputs=[input_video],
)
        with gr.Tab(label='Image-seq example'):
gr.Examples(
examples=[
os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.zip"),
],
inputs=[input_img_seq],
)
app.queue(concurrency_count=1)
app.launch(debug=True, enable_queue=True, share=True)
if __name__ == "__main__":
seg_track_app()