Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
11554c5
1
Parent(s):
9e24bfb
complete task 1
Browse files
- .gitignore +4 -0
- app.py +203 -4
- config.yaml +138 -0
- export.py +9 -0
- requirements.txt +3 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface
|
2 |
+
.DS_Store
|
3 |
+
data/
|
4 |
+
__pycache__/
|
app.py
CHANGED
@@ -1,7 +1,206 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import spaces
|
3 |
import gradio as gr
|
4 |
+
import imageio
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from algorithms.dfot import DFoTVideoPose
|
9 |
+
from utils.ckpt_utils import download_pretrained
|
10 |
+
from datasets.video.utils.io import read_video
|
11 |
+
from datasets.video import RealEstate10KAdvancedVideoDataset
|
12 |
+
from export import export_to_video
|
13 |
|
14 |
+
# Location of the bundled sample dataset and the length of the generated clips.
DATASET_DIR = Path("data/real-estate-10k-tiny")
LONG_LENGTH = 20  # seconds

# Test-split index; only its "video_paths" entry is read below.
metadata = torch.load(DATASET_DIR / "metadata" / "test.pt", weights_only=False)


def _load_clip(video_path):
    """Decode one clip, move its last axis to position 1 and divide by 255."""
    # presumably (T, H, W, C) uint8 -> (T, C, H, W) in 0..1 -- confirm against read_video
    return read_video(video_path).permute(0, 3, 1, 2) / 255.0


def _first_frame(clip):
    """Return frame 0 of ``clip`` as a uint8 NumPy array for gallery display."""
    scaled = (clip[0] * 255).permute(1, 2, 0).numpy()
    return scaled.clip(0, 255).astype("uint8")


def _load_poses(video_path):
    """Load the pose file named after ``video_path`` and drop columns 4 and 5."""
    # NOTE(review): relies on ``video_path`` exposing ``.stem`` (e.g. a pathlib.Path).
    raw = torch.load(DATASET_DIR / "test_poses" / f"{video_path.stem}.pt")
    kept = torch.cat([raw[:, :4], raw[:, 6:]], dim=-1)
    return kept.to(torch.float32)


# Everything is loaded eagerly at import time, in this order.
video_list = [_load_clip(video_path) for video_path in metadata["video_paths"]]
first_frame_list = [_first_frame(clip) for clip in video_list]
poses_list = [_load_poses(video_path) for video_path in metadata["video_paths"]]

# Build the model from the pretrained checkpoint, switch it to eval mode and
# move it to the GPU.
# pylint: disable-next=no-value-for-parameter
dfot = DFoTVideoPose.load_from_checkpoint(
    checkpoint_path=download_pretrained("pretrained:DFoT_RE10K.ckpt"),
    cfg=OmegaConf.load("config.yaml"),
).eval()
dfot.to("cuda")
|
45 |
+
|
46 |
+
def prepare_long_gt_video(idx: int):
    """Export an evenly subsampled copy of ground-truth clip ``idx``.

    ``LONG_LENGTH * 10`` frame indices are spread evenly over the clip with
    ``torch.linspace(..., dtype=torch.long)`` and the selected frames are
    written out at 10 fps.

    Args:
        idx: Position of the clip in ``video_list``.

    Returns:
        The value returned by ``export_to_video`` for the selected frames.
    """
    gt_fps = 10  # used both for the frame count and for the export rate
    clip = video_list[idx]
    last_index = clip.size(0) - 1
    frame_ids = torch.linspace(0, last_index, LONG_LENGTH * gt_fps, dtype=torch.long)
    return export_to_video(clip[frame_ids], fps=gt_fps)
|
50 |
+
|
51 |
+
@spaces.GPU(duration=120)
@torch.no_grad()
def single_image_to_long_video(idx: int, guidance_scale: float, fps: int, progress=gr.Progress(track_tqdm=True)):
    """Generate a ``LONG_LENGTH``-second video for sample ``idx`` and export it.

    Frames and poses of the chosen sample are subsampled to
    ``LONG_LENGTH * fps`` evenly spaced indices, batched, moved to the GPU
    and passed through the model's normalize -> predict -> unnormalize
    steps. The first element of the result is exported as a video.

    Args:
        idx: Position of the sample in ``video_list`` / ``poses_list``.
        guidance_scale: Written to the prediction task's history-guidance
            scale before sampling.
        fps: Frame rate of the output; also sets the number of frames and
            the keyframe density (``0.6 / fps``).
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        The value returned by ``export_to_video`` for the generated frames.
    """
    clip = video_list[idx]
    clip_poses = poses_list[idx]
    frame_ids = torch.linspace(0, clip.size(0) - 1, LONG_LENGTH * fps, dtype=torch.long)

    # Add a leading batch axis of size 1 and move both inputs to the GPU.
    xs = clip[frame_ids].unsqueeze(0).to("cuda")
    conditions = clip_poses[frame_ids].unsqueeze(0).to("cuda")

    # NOTE(review): these writes mutate the module-level model config, so the
    # values persist across calls -- confirm this is safe under concurrent
    # requests. The interpolation task's guidance scale is not modified here.
    dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
    dfot.cfg.tasks.prediction.keyframe_density = 0.6 / fps

    normalized = dfot._normalize_x(xs)
    predicted = dfot._predict_videos(normalized, conditions)
    gen_video = dfot._unnormalize_x(predicted)
    return export_to_video(gen_video[0].detach().cpu(), fps=fps)
|
69 |
+
|
70 |
+
|
71 |
+
# Assemble the Gradio Blocks UI: a header, link buttons and three task tabs.
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
    # Inline CSS that enlarges and bolds the three tab labels.
    gr.HTML(
        """
    <style>
    [data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] {
        font-size: 16px !important;
        font-weight: bold;
    }
    </style>
    """
    )

    gr.Markdown("# Diffusion Forcing Transformer and History Guidance")
    gr.Markdown(
        "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
    )

    # NOTE(review): the "Website" link is still the placeholder "todo", while
    # the "Paper" button points at what looks like the project site -- confirm
    # the intended targets before release.
    with gr.Row():
        gr.Button(value="🌐 Website", link="todo")
        gr.Button(value="📄 Paper", link="https://boyuan.space/history-guidance")
        gr.Button(
            value="💻 Code",
            link="https://github.com/kwsong0113/diffusion-forcing-transformer",
        )
        gr.Button(
            value="🤗 Pretrained Models", link="https://huggingface.co/kiwhansong/DFoT"
        )

    # NOTE(review): the tab with id "task-1" is titled "Demo 2" and the tab
    # with id "task-2" is titled "Demo 1" -- confirm the numbering is intended.
    with gr.Tab("Single Image → Long Video", id="task-1"):
        gr.Markdown(
            """
            ## Demo 2: Single Image → Long Video
            > #### **TL;DR:** _Diffusion Forcing Transformer, with History Guidance, can stably generate long videos, via sliding window rollouts and interpolation._
        """
        )

        # Two-step flow: "Selection" (pick an image) -> "Generation" (run the model).
        stage = gr.State(value="Selection")
        selected_index = gr.State(value=None)

        @gr.render(inputs=[stage, selected_index])
        def render_stage(s, idx):
            """Render the widgets for the current stage of the task-1 flow."""
            if s == "Selection":
                image_gallery = gr.Gallery(
                    value=first_frame_list,
                    label="Select an image to animate",
                    columns=[8],
                    selected_index=idx,
                )

                # Remember which gallery item was clicked.
                @image_gallery.select(inputs=None, outputs=selected_index)
                def update_selection(selection: gr.SelectData):
                    return selection.index

                select_button = gr.Button("Select")

                # Advance only once an image has been chosen.
                @select_button.click(inputs=selected_index, outputs=stage)
                def move_to_generation(idx: int):
                    if idx is not None:
                        return "Generation"
                    gr.Warning("Image not selected!")
                    return "Selection"

            elif s == "Generation":
                with gr.Row():
                    gr.Image(value=first_frame_list[idx], label="Input Image")
                    gr.Video(value=prepare_long_gt_video(idx), label="Ground Truth Video")
                    video = gr.Video(label="Generated Video")

                with gr.Column():
                    guidance_scale = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=4,
                        step=0.5,
                        label="History Guidance Scale",
                        info="Without history guidance: 1.0; Recommended: 4.0",
                        interactive=True,
                    )
                    fps = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="FPS",
                        info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
                        interactive=True,
                    )
                    generate_button = gr.Button("Generate Video")
                    generate_button.click(
                        fn=single_image_to_long_video,
                        inputs=[selected_index, guidance_scale, fps],
                        outputs=video,
                    )

    # The two tabs below only lay out widgets; no event handler is attached yet.
    with gr.Tab("Any Images → Video", id="task-2"):
        gr.Markdown(
            """
            ## Demo 1: Any Images → Video
            > #### **TL;DR:** _Diffusion Forcing Transformer is a flexible model that can generate videos given variable number of context frames._
        """
        )
        input_text_1 = gr.Textbox(
            lines=2, placeholder="Enter text for Video Model 1..."
        )
        output_video_1 = gr.Video()
        generate_button_1 = gr.Button("Generate Video")

    # NOTE(review): this TL;DR repeats the "Any Images → Video" blurb -- looks
    # like placeholder text; confirm.
    with gr.Tab("Single Image → Extremely Long Video", id="task-3"):
        gr.Markdown(
            """
            ## Demo 3: Single Image → Extremely Long Video
            > #### **TL;DR:** _Diffusion Forcing Transformer is a flexible model that can generate videos given **variable number of context frames**._
        """
        )
        input_text_2 = gr.Textbox(
            lines=2, placeholder="Enter text for Video Model 2..."
        )
        output_video_2 = gr.Video()
        generate_button_2 = gr.Button("Generate Video")

if __name__ == "__main__":
    demo.launch()
|
config.yaml
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
debug: False
|
2 |
+
lr: 5e-5
|
3 |
+
backbone:
|
4 |
+
name: u_vit3d_pose
|
5 |
+
channels:
|
6 |
+
- 128
|
7 |
+
- 256
|
8 |
+
- 576
|
9 |
+
- 1152
|
10 |
+
emb_channels: 1024
|
11 |
+
patch_size: 2
|
12 |
+
block_types:
|
13 |
+
- ResBlock
|
14 |
+
- ResBlock
|
15 |
+
- TransformerBlock
|
16 |
+
- TransformerBlock
|
17 |
+
block_dropouts:
|
18 |
+
- 0.0
|
19 |
+
- 0.0
|
20 |
+
- 0.1
|
21 |
+
- 0.1
|
22 |
+
num_updown_blocks:
|
23 |
+
- 3
|
24 |
+
- 3
|
25 |
+
- 6
|
26 |
+
num_mid_blocks: 20
|
27 |
+
num_heads: 9
|
28 |
+
pos_emb_type: rope
|
29 |
+
use_checkpointing:
|
30 |
+
- false
|
31 |
+
- false
|
32 |
+
- false
|
33 |
+
- true
|
34 |
+
conditioning:
|
35 |
+
dim: null
|
36 |
+
external_cond_dropout: 0.1
|
37 |
+
use_fourier_noise_embedding: true
|
38 |
+
x_shape: [3, 256, 256]
|
39 |
+
max_frames: 8
|
40 |
+
n_frames: 8
|
41 |
+
frame_skip: 1
|
42 |
+
context_frames: 1
|
43 |
+
latent:
|
44 |
+
enable: False
|
45 |
+
type: pre_sample
|
46 |
+
suffix: null
|
47 |
+
downsampling_factor: [1, 8]
|
48 |
+
num_channels: 4
|
49 |
+
data_mean: [[[0.577]], [[0.517]], [[0.461]]]
|
50 |
+
data_std: [[[0.249]], [[0.249]], [[0.268]]]
|
51 |
+
external_cond_dim: 16
|
52 |
+
external_cond_stack: False
|
53 |
+
external_cond_processing: null
|
54 |
+
compile: false
|
55 |
+
weight_decay: 0.01
|
56 |
+
optimizer_beta:
|
57 |
+
- 0.9
|
58 |
+
- 0.99
|
59 |
+
lr_scheduler:
|
60 |
+
name: constant_with_warmup
|
61 |
+
num_warmup_steps: 10000
|
62 |
+
num_training_steps: 550000
|
63 |
+
noise_level: random_independent
|
64 |
+
uniform_future:
|
65 |
+
enabled: false
|
66 |
+
fixed_context:
|
67 |
+
enabled: false
|
68 |
+
indices: null
|
69 |
+
dropout: 0
|
70 |
+
variable_context:
|
71 |
+
enabled: false
|
72 |
+
prob: 0
|
73 |
+
dropout: 0
|
74 |
+
chunk_size: -1
|
75 |
+
scheduling_matrix: full_sequence
|
76 |
+
replacement: noisy_scale
|
77 |
+
diffusion:
|
78 |
+
is_continuous: true
|
79 |
+
timesteps: 1000
|
80 |
+
beta_schedule: cosine_simple_diffusion
|
81 |
+
schedule_fn_kwargs:
|
82 |
+
shift: 1.0
|
83 |
+
shifted: 0.125
|
84 |
+
interpolated: false
|
85 |
+
use_causal_mask: false
|
86 |
+
clip_noise: 20.0
|
87 |
+
objective: pred_v
|
88 |
+
loss_weighting:
|
89 |
+
strategy: sigmoid
|
90 |
+
snr_clip: 5.0
|
91 |
+
cum_snr_decay: 0.9
|
92 |
+
sigmoid_bias: -1.0
|
93 |
+
sampling_timesteps: 50
|
94 |
+
ddim_sampling_eta: 0.0
|
95 |
+
reconstruction_guidance: 0.0
|
96 |
+
training_schedule:
|
97 |
+
name: cosine
|
98 |
+
shift: 0.125
|
99 |
+
precond_scale: 0.125
|
100 |
+
vae:
|
101 |
+
pretrained_path: null
|
102 |
+
pretrained_kwargs: {}
|
103 |
+
use_fp16: true
|
104 |
+
batch_size: 2
|
105 |
+
checkpoint:
|
106 |
+
reset_optimizer: false
|
107 |
+
strict: true
|
108 |
+
tasks:
|
109 |
+
prediction:
|
110 |
+
enabled: true
|
111 |
+
history_guidance:
|
112 |
+
name: stabilized_vanilla
|
113 |
+
guidance_scale: 4.0
|
114 |
+
stabilization_level: 0.02
|
115 |
+
visualize: False
|
116 |
+
keyframe_density: null
|
117 |
+
sliding_context_len: null
|
118 |
+
interpolation:
|
119 |
+
enabled: false
|
120 |
+
history_guidance:
|
121 |
+
name: vanilla
|
122 |
+
guidance_scale: 1.5
|
123 |
+
visualize: False
|
124 |
+
max_batch_size: 1
|
125 |
+
logging:
|
126 |
+
deterministic: null
|
127 |
+
loss_freq: 100
|
128 |
+
grad_norm_freq: 100
|
129 |
+
max_num_videos: 256
|
130 |
+
n_metrics_frames: null
|
131 |
+
metrics: []
|
132 |
+
metrics_batch_size: 16
|
133 |
+
sanity_generation: false
|
134 |
+
raw_dir: null
|
135 |
+
camera_pose_conditioning:
|
136 |
+
normalize_by: first
|
137 |
+
bound: null
|
138 |
+
type: ray_encoding
|
export.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import torch
|
3 |
+
from torch import Tensor
|
4 |
+
from torchvision.io import write_video
|
5 |
+
|
6 |
+
def export_to_video(tensor: Tensor, fps: int = 10) -> str:
    """Write a frame tensor to a temporary ``.mp4`` file and return its path.

    The frames are permuted with ``permute(0, 2, 3, 1)``, multiplied by 255,
    clamped to 0..255 and cast to ``uint8`` before being handed to
    ``write_video``.

    Args:
        tensor: Frames to encode.
            # assumes (T, C, H, W) with values in 0..1 -- TODO confirm with callers
        fps: Frame rate passed to ``write_video``.

    Returns:
        Path of the written file. The file is created with ``delete=False``,
        so it is not removed automatically; cleanup is up to the caller.
    """
    # Reserve the path and close the handle explicitly. The previous
    # ``NamedTemporaryFile(...).name`` form dropped the only reference to the
    # file object, so the file was deleted as soon as it was garbage
    # collected and the bare name was reused afterwards (an mktemp-style race).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        path = handle.name

    frames = (tensor.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
    write_video(path, frames, fps=fps)
    return path
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# gradio
|
2 |
+
# spaces
|
3 |
+
git+https://github.com/kwsong0113/dfot-test.git@release#egg=dfot # FIXME: change to the official repo
|