kiwhansong committed on
Commit 11554c5 · 1 Parent(s): 9e24bfb

complete task 1

Files changed (5)
  1. .gitignore +4 -0
  2. app.py +203 -4
  3. config.yaml +138 -0
  4. export.py +9 -0
  5. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ huggingface
+ .DS_Store
+ data/
+ __pycache__/
app.py CHANGED
@@ -1,7 +1,206 @@
+ from pathlib import Path
+ import spaces
  import gradio as gr
+ import imageio
+ import torch
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from algorithms.dfot import DFoTVideoPose
+ from utils.ckpt_utils import download_pretrained
+ from datasets.video.utils.io import read_video
+ from datasets.video import RealEstate10KAdvancedVideoDataset
+ from export import export_to_video

- def greet(name):
-     return "Hello " + name + "!!"
+ DATASET_DIR = Path("data/real-estate-10k-tiny")
+ LONG_LENGTH = 20 # seconds

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ metadata = torch.load(DATASET_DIR / "metadata" / "test.pt", weights_only=False)
+ video_list = [
+     read_video(path).permute(0, 3, 1, 2) / 255.0 for path in metadata["video_paths"]
+ ]
+ first_frame_list = [
+     (video[0] * 255).permute(1, 2, 0).numpy().clip(0, 255).astype("uint8")
+     for video in video_list
+ ]
+ poses_list = [
+     torch.cat(
+         [
+             poses[:, :4],
+             poses[:, 6:],
+         ],
+         dim=-1,
+     ).to(torch.float32)
+     for poses in (
+         torch.load(DATASET_DIR / "test_poses" / f"{path.stem}.pt")
+         for path in metadata["video_paths"]
+     )
+ ]
+
+ # pylint: disable-next=no-value-for-parameter
+ dfot = DFoTVideoPose.load_from_checkpoint(
+     checkpoint_path=download_pretrained("pretrained:DFoT_RE10K.ckpt"),
+     cfg=OmegaConf.load("config.yaml"),
+ ).eval()
+ dfot.to("cuda")
+
+ def prepare_long_gt_video(idx: int):
+     video = video_list[idx]
+     indices = torch.linspace(0, video.size(0) - 1, LONG_LENGTH * 10, dtype=torch.long)
+     return export_to_video(video[indices], fps=10)
+
+ @spaces.GPU(duration=120)
+ @torch.no_grad()
+ def single_image_to_long_video(idx: int, guidance_scale: float, fps: int, progress=gr.Progress(track_tqdm=True)):
+     video = video_list[idx]
+     poses = poses_list[idx]
+     indices = torch.linspace(0, video.size(0) - 1, LONG_LENGTH * fps, dtype=torch.long)
+     xs = video[indices].unsqueeze(0).to("cuda")
+     conditions = poses[indices].unsqueeze(0).to("cuda")
+     dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
+     dfot.cfg.tasks.prediction.keyframe_density = 0.6 / fps
+     # dfot.cfg.tasks.interpolation.history_guidance.guidance_scale = guidance_scale
+     gen_video = dfot._unnormalize_x(
+         dfot._predict_videos(
+             dfot._normalize_x(xs),
+             conditions,
+         )
+     )
+     return export_to_video(gen_video[0].detach().cpu(), fps=fps)
+
+
+ # Create the Gradio Blocks
+ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
+     gr.HTML(
+         """
+         <style>
+         [data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] {
+             font-size: 16px !important;
+             font-weight: bold;
+         }
+         </style>
+         """
+     )
+
+     gr.Markdown("# Diffusion Forcing Transformer and History Guidance")
+     gr.Markdown(
+         "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
+     )
+     with gr.Row():
+         gr.Button(value="🌐 Website", link="todo")
+         gr.Button(value="📄 Paper", link="https://boyuan.space/history-guidance")
+         gr.Button(
+             value="💻 Code",
+             link="https://github.com/kwsong0113/diffusion-forcing-transformer",
+         )
+         gr.Button(
+             value="🤗 Pretrained Models", link="https://huggingface.co/kiwhansong/DFoT"
+         )
+
+     with gr.Tab("Single Image → Long Video", id="task-1"):
+         gr.Markdown(
+             """
+             ## Demo 2: Single Image → Long Video
+             > #### **TL;DR:** _Diffusion Forcing Transformer, with History Guidance, can stably generate long videos, via sliding window rollouts and interpolation._
+             """
+         )
+
+         stage = gr.State(value="Selection")
+         selected_index = gr.State(value=None)
+
+         @gr.render(inputs=[stage, selected_index])
+         def render_stage(s, idx):
+             match s:
+                 case "Selection":
+                     image_gallery = gr.Gallery(
+                         value=first_frame_list,
+                         label="Select an image to animate",
+                         columns=[8],
+                         selected_index=idx,
+                     )
+
+                     @image_gallery.select(inputs=None, outputs=selected_index)
+                     def update_selection(selection: gr.SelectData):
+                         return selection.index
+
+                     select_button = gr.Button("Select")
+
+                     @select_button.click(inputs=selected_index, outputs=stage)
+                     def move_to_generation(idx: int):
+                         if idx is None:
+                             gr.Warning("Image not selected!")
+                             return "Selection"
+                         else:
+                             return "Generation"
+
+                 case "Generation":
+                     with gr.Row():
+                         gr.Image(value=first_frame_list[idx], label="Input Image")
+                         # gr.Video(value=metadata["video_paths"][idx], label="Ground Truth Video")
+                         gr.Video(value=prepare_long_gt_video(idx), label="Ground Truth Video")
+                         video = gr.Video(label="Generated Video")
+
+                     with gr.Column():
+                         guidance_scale = gr.Slider(
+                             minimum=1,
+                             maximum=6,
+                             value=4,
+                             step=0.5,
+                             label="History Guidance Scale",
+                             info="Without history guidance: 1.0; Recommended: 4.0",
+                             interactive=True,
+                         )
+                         fps = gr.Slider(
+                             minimum=1,
+                             maximum=10,
+                             value=4,
+                             step=1,
+                             label="FPS",
+                             info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
+                             interactive=True,
+                         )
+                         generate_button = gr.Button("Generate Video").click(
+                             fn=single_image_to_long_video,
+                             inputs=[selected_index, guidance_scale, fps],
+                             outputs=video,
+                         )
+                     # def generate_video(idx: int):
+                     #     gr.Video(value=single_image_to_long_video(idx))
+
+         # Function to update the state with the selected index
+
+         # def show_warning(selection: gr.SelectData):
+         #     gr.Warning(f"Your choice is #{selection.index}, with image: {selection.value['image']['path']}!")
+
+         # # image_gallery.select(fn=show_warning, inputs=None)
+
+         # # Show the generate button only if an image is selected
+         # selected_index.change(fn=lambda idx: idx is not None, inputs=selected_index, outputs=generate_button)
+
+     with gr.Tab("Any Images → Video", id="task-2"):
+         gr.Markdown(
+             """
+             ## Demo 1: Any Images → Video
+             > #### **TL;DR:** _Diffusion Forcing Transformer is a flexible model that can generate videos given variable number of context frames._
+             """
+         )
+         input_text_1 = gr.Textbox(
+             lines=2, placeholder="Enter text for Video Model 1..."
+         )
+         output_video_1 = gr.Video()
+         generate_button_1 = gr.Button("Generate Video")
+
+     with gr.Tab("Single Image → Extremely Long Video", id="task-3"):
+         gr.Markdown(
+             """
+             ## Demo 3: Single Image → Extremely Long Video
+             > #### **TL;DR:** _Diffusion Forcing Transformer is a flexible model that can generate videos given **variable number of context frames**._
+             """
+         )
+         input_text_2 = gr.Textbox(
+             lines=2, placeholder="Enter text for Video Model 2..."
+         )
+         output_video_2 = gr.Video()
+         generate_button_2 = gr.Button("Generate Video")
+
+ if __name__ == "__main__":
+     demo.launch()
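For reference, the generation path behind the "Single Image → Long Video" tab can be exercised outside Gradio. The sketch below is not part of the commit: it assumes the same repository modules imported in app.py (`algorithms.dfot`, `utils.ckpt_utils`, `datasets.video`, `export`), the tiny RealEstate10K subset under `data/real-estate-10k-tiny`, and an available CUDA device, and it reuses the same (private) helpers the demo calls.

```python
# Hypothetical standalone sketch mirroring single_image_to_long_video() above.
from pathlib import Path

import torch
from omegaconf import OmegaConf

from algorithms.dfot import DFoTVideoPose
from utils.ckpt_utils import download_pretrained
from datasets.video.utils.io import read_video
from export import export_to_video

DATASET_DIR = Path("data/real-estate-10k-tiny")
FPS, LENGTH = 4, 20  # 20-second clip at 4 FPS, matching the demo defaults

metadata = torch.load(DATASET_DIR / "metadata" / "test.pt", weights_only=False)
path = metadata["video_paths"][0]                      # first test scene
video = read_video(path).permute(0, 3, 1, 2) / 255.0   # (T, C, H, W) in [0, 1]
poses = torch.load(DATASET_DIR / "test_poses" / f"{path.stem}.pt")
poses = torch.cat([poses[:, :4], poses[:, 6:]], dim=-1).to(torch.float32)

dfot = DFoTVideoPose.load_from_checkpoint(
    checkpoint_path=download_pretrained("pretrained:DFoT_RE10K.ckpt"),
    cfg=OmegaConf.load("config.yaml"),
).eval().to("cuda")

# Subsample frames/poses uniformly, then predict the full clip from frame 0.
indices = torch.linspace(0, video.size(0) - 1, LENGTH * FPS, dtype=torch.long)
with torch.no_grad():
    gen = dfot._unnormalize_x(
        dfot._predict_videos(
            dfot._normalize_x(video[indices].unsqueeze(0).to("cuda")),
            poses[indices].unsqueeze(0).to("cuda"),
        )
    )
print(export_to_video(gen[0].detach().cpu(), fps=FPS))  # path to the generated .mp4
```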
config.yaml ADDED
@@ -0,0 +1,138 @@
+ debug: False
+ lr: 5e-5
+ backbone:
+   name: u_vit3d_pose
+   channels:
+     - 128
+     - 256
+     - 576
+     - 1152
+   emb_channels: 1024
+   patch_size: 2
+   block_types:
+     - ResBlock
+     - ResBlock
+     - TransformerBlock
+     - TransformerBlock
+   block_dropouts:
+     - 0.0
+     - 0.0
+     - 0.1
+     - 0.1
+   num_updown_blocks:
+     - 3
+     - 3
+     - 6
+   num_mid_blocks: 20
+   num_heads: 9
+   pos_emb_type: rope
+   use_checkpointing:
+     - false
+     - false
+     - false
+     - true
+   conditioning:
+     dim: null
+   external_cond_dropout: 0.1
+   use_fourier_noise_embedding: true
+ x_shape: [3, 256, 256]
+ max_frames: 8
+ n_frames: 8
+ frame_skip: 1
+ context_frames: 1
+ latent:
+   enable: False
+   type: pre_sample
+   suffix: null
+   downsampling_factor: [1, 8]
+   num_channels: 4
+ data_mean: [[[0.577]], [[0.517]], [[0.461]]]
+ data_std: [[[0.249]], [[0.249]], [[0.268]]]
+ external_cond_dim: 16
+ external_cond_stack: False
+ external_cond_processing: null
+ compile: false
+ weight_decay: 0.01
+ optimizer_beta:
+   - 0.9
+   - 0.99
+ lr_scheduler:
+   name: constant_with_warmup
+   num_warmup_steps: 10000
+   num_training_steps: 550000
+ noise_level: random_independent
+ uniform_future:
+   enabled: false
+ fixed_context:
+   enabled: false
+   indices: null
+   dropout: 0
+ variable_context:
+   enabled: false
+   prob: 0
+   dropout: 0
+ chunk_size: -1
+ scheduling_matrix: full_sequence
+ replacement: noisy_scale
+ diffusion:
+   is_continuous: true
+   timesteps: 1000
+   beta_schedule: cosine_simple_diffusion
+   schedule_fn_kwargs:
+     shift: 1.0
+     shifted: 0.125
+     interpolated: false
+   use_causal_mask: false
+   clip_noise: 20.0
+   objective: pred_v
+   loss_weighting:
+     strategy: sigmoid
+     snr_clip: 5.0
+     cum_snr_decay: 0.9
+     sigmoid_bias: -1.0
+   sampling_timesteps: 50
+   ddim_sampling_eta: 0.0
+   reconstruction_guidance: 0.0
+   training_schedule:
+     name: cosine
+     shift: 0.125
+   precond_scale: 0.125
+ vae:
+   pretrained_path: null
+   pretrained_kwargs: {}
+   use_fp16: true
+   batch_size: 2
+ checkpoint:
+   reset_optimizer: false
+   strict: true
+ tasks:
+   prediction:
+     enabled: true
+     history_guidance:
+       name: stabilized_vanilla
+       guidance_scale: 4.0
+       stabilization_level: 0.02
+       visualize: False
+     keyframe_density: null
+     sliding_context_len: null
+   interpolation:
+     enabled: false
+     history_guidance:
+       name: vanilla
+       guidance_scale: 1.5
+       visualize: False
+     max_batch_size: 1
+ logging:
+   deterministic: null
+   loss_freq: 100
+   grad_norm_freq: 100
+   max_num_videos: 256
+   n_metrics_frames: null
+   metrics: []
+   metrics_batch_size: 16
+   sanity_generation: false
+   raw_dir: null
+ camera_pose_conditioning:
+   normalize_by: first
+   bound: null
+   type: ray_encoding
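app.py loads this file with OmegaConf and overrides the prediction-task guidance settings per request. A minimal sketch (assuming the key nesting shown in this diff) of how those fields are read and mutated:

```python
# Sketch only: inspect the two fields that app.py overrides at request time.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
print(cfg.tasks.prediction.history_guidance.guidance_scale)  # 4.0 (default)
print(cfg.tasks.prediction.keyframe_density)                 # None

# app.py sets these from the Gradio sliders before sampling:
cfg.tasks.prediction.history_guidance.guidance_scale = 6.0
cfg.tasks.prediction.keyframe_density = 0.6 / 10  # e.g. for a 10-FPS request
```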
export.py ADDED
@@ -0,0 +1,9 @@
+ import tempfile
+ import torch
+ from torch import Tensor
+ from torchvision.io import write_video
+
+ def export_to_video(tensor: Tensor, fps: int = 10) -> str:
+     path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+     write_video(path, (tensor.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8), fps=fps)
+     return path
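`export_to_video` expects a float tensor of shape `(T, C, H, W)` with values in `[0, 1]` (which is what app.py passes) and returns the path of a temporary `.mp4`. A quick usage sketch, assuming torchvision's video backend (ffmpeg/PyAV) is installed:

```python
# Sketch: write a random 2-second, 10-FPS clip and print its temporary path.
import torch
from export import export_to_video

frames = torch.rand(20, 3, 256, 256)  # (T, C, H, W), values in [0, 1]
print(export_to_video(frames, fps=10))
```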
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ # gradio
+ # spaces
+ git+https://github.com/kwsong0113/dfot-test.git@release#egg=dfot # FIXME: change to the official repo