hugoycj commited on
Commit
77033ff
1 Parent(s): a48a1c1

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *venv
2
+ saved_videos
3
+ flagged
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from cotracker.utils.visualizer import Visualizer, read_video_from_path
8
+ from cotracker.predictor import CoTrackerPredictor
9
+
10
+ checkpoint='./checkpoints/cotracker_stride_4_wind_8.pth'
11
+ def cotracker(video_path: str, grid_size: int, grid_query_frame: int, backward_tracking: bool):
12
+ # load the input video frame by frame
13
+ video = read_video_from_path(video_path)
14
+ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
15
+ model = CoTrackerPredictor(checkpoint=checkpoint)
16
+ if torch.cuda.is_available():
17
+ model = model.cuda()
18
+ video = video.cuda()
19
+ else:
20
+ print("CUDA is not available!")
21
+
22
+ pred_tracks, pred_visibility = model(
23
+ video,
24
+ grid_size=grid_size,
25
+ grid_query_frame=grid_query_frame,
26
+ backward_tracking=backward_tracking,
27
+ )
28
+ print("computed")
29
+
30
+ # save a video with predicted tracks
31
+ seq_name = video_path.split("/")[-1]
32
+ vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
33
+ vis.visualize(video, pred_tracks, query_frame=grid_query_frame)
34
+
35
+ return "./saved_videos/video_pred_track.mp4"
36
+
37
+ iface = gr.Interface(
38
+ fn=cotracker,
39
+ inputs=[
40
+ gr.inputs.Video(label='video', type='mp4'),
41
+ gr.inputs.Slider(minimum=0, maximum=20, step=1, default=10, label="Grid Size"),
42
+ gr.inputs.Slider(minimum=0, maximum=10, step=1, default=0, label="Grid Query Frame"),
43
+ gr.inputs.Checkbox(label="Backward Tracking"),
44
+ ],
45
+ outputs=gr.outputs.Video(label="Output")
46
+ )
47
+ iface.queue()
48
+ iface.launch()
checkpoints/cotracker_stride_4_wind_8.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbb9006c0dc89479c34993ac23d37ae0af13deac18be6028009d7483754b2fc3
3
+ size 96660885
packages.txt ADDED
File without changes
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch==1.13.0
2
+ torchvision==0.14.0
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra-core
2
+ omegaconf
3
+ einops
4
+ timm
5
+ tqdm
6
+ opencv-python
7
+ matplotlib
8
+ moviepy
9
+ flow_vis
10
+ git+https://github.com/facebookresearch/co-tracker.git