kiwhansong committed
Commit 5359939 · 1 Parent(s): eb1feee

finish demo

Files changed (3)
  1. app.py +549 -49
  2. camera_pose.py +94 -0
  3. history_guidance.py +24 -0
app.py CHANGED
@@ -6,19 +6,20 @@ import gradio as gr
6
  import numpy as np
7
  import torch
8
  from torchvision.datasets.utils import download_and_extract_archive
9
- from PIL import Image
10
  from omegaconf import OmegaConf
11
  from algorithms.dfot import DFoTVideoPose
12
- from algorithms.dfot.history_guidance import HistoryGuidance
13
  from utils.ckpt_utils import download_pretrained
14
- from utils.huggingface_utils import download_from_hf
15
  from datasets.video.utils.io import read_video
16
- from datasets.video import RealEstate10KAdvancedVideoDataset
17
  from export import export_to_video, export_to_gif, export_images_to_gif
18
 
19
  DATASET_URL = "https://huggingface.co/kiwhansong/DFoT/resolve/main/datasets/RealEstate10K_Tiny.tar.gz"
20
  DATASET_DIR = Path("data/real-estate-10k-tiny")
21
- LONG_LENGTH = 20 # seconds
 
22
 
23
  if not DATASET_DIR.exists():
24
  DATASET_DIR.mkdir(parents=True)
@@ -69,8 +70,8 @@ dfot.to("cuda")
69
 
70
  def prepare_long_gt_video(idx: int):
71
  video = video_list[idx]
72
- indices = torch.linspace(0, video.size(0) - 1, LONG_LENGTH * 10, dtype=torch.long)
73
- return export_to_video(video[indices], fps=10)
74
 
75
 
76
  def prepare_short_gt_video(idx: int):
@@ -104,7 +105,7 @@ def single_image_to_long_video(
104
  xs = video[indices].unsqueeze(0).to("cuda")
105
  conditions = poses[indices].unsqueeze(0).to("cuda")
106
  dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
107
- dfot.cfg.tasks.prediction.keyframe_density = 0.6 / fps
108
  # dfot.cfg.tasks.interpolation.history_guidance.guidance_scale = guidance_scale
109
  gen_video = dfot._unnormalize_x(
110
  dfot._predict_videos(
@@ -151,6 +152,228 @@ def any_images_to_short_video(
151
  return video_to_gif_and_images([image for image in gen_video], list(range(8)))
152
 
153
 
154
  # Create the Gradio Blocks
155
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
156
  gr.HTML(
@@ -160,6 +383,21 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
160
  font-size: 16px !important;
161
  font-weight: bold;
162
  }
163
  </style>
164
  """
165
  )
@@ -169,14 +407,29 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
169
  "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
170
  )
171
  with gr.Row():
172
- gr.Button(value="🌐 Website", link="https://boyuan.space/history-guidance")
173
- gr.Button(value="📄 Paper", link="https://arxiv.org/abs/2502.06764")
174
  gr.Button(
175
- value="💻 Code",
176
  link="https://github.com/kwsong0113/diffusion-forcing-transformer",
177
  )
178
  gr.Button(
179
- value="🤗 Pretrained Models", link="https://huggingface.co/kiwhansong/DFoT"
180
  )
181
 
182
  with gr.Accordion("Troubleshooting: Not Working or Too Slow?", open=False):
@@ -187,7 +440,6 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
187
  """
188
  )
189
 
190
-
191
  with gr.Tab("Any # of Images → Short Video", id="task-1"):
192
  gr.Markdown(
193
  """
@@ -225,7 +477,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
225
  def update_selection(selection: gr.SelectData):
226
  return selection.index
227
 
228
- demo1_scene_select_button = gr.Button("Select Scene")
229
 
230
  @demo1_scene_select_button.click(
231
  inputs=demo1_selected_scene_index, outputs=demo1_stage
@@ -257,7 +509,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
257
  choices=[(f"t={i}", i) for i in range(8)],
258
  value=[],
259
  )
260
- demo1_image_select_button = gr.Button("Select Input Images")
261
 
262
  @demo1_image_select_button.click(
263
  inputs=[demo1_selector],
@@ -304,7 +556,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
304
  info="Without history guidance: 1.0; Recommended: 4.0",
305
  interactive=True,
306
  )
307
- gr.Button("Generate Video").click(
308
  fn=any_images_to_short_video,
309
  inputs=[
310
  demo1_selected_scene_index,
@@ -316,9 +568,9 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
316
 
317
  with gr.Tab("Single Image → Long Video", id="task-2"):
318
  gr.Markdown(
319
- """
320
- ## Demo 2: Single Image → Long 20-second Video
321
- > #### _Diffusion Forcing Transformer, with History Guidance, can generate long videos via sliding window rollouts and temporal super-resolution._
322
  """
323
  )
324
 
@@ -344,7 +596,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
344
  def update_selection(selection: gr.SelectData):
345
  return selection.index
346
 
347
- demo2_select_button = gr.Button("Select Input Image")
348
 
349
  @demo2_select_button.click(
350
  inputs=demo2_selected_index, outputs=demo2_stage
@@ -369,49 +621,297 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
369
  label="Ground Truth Video",
370
  width=256,
371
  height=256,
372
  )
373
  demo2_video = gr.Video(
374
- label="Generated Video", width=256, height=256
375
  )
376
 
377
- with gr.Sidebar():
378
- gr.Markdown("### Sampling Parameters")
379
 
380
- demo2_guidance_scale = gr.Slider(
381
- minimum=1,
382
- maximum=6,
383
- value=4,
384
- step=0.5,
385
- label="History Guidance Scale",
386
- info="Without history guidance: 1.0; Recommended: 4.0",
387
- interactive=True,
388
  )
389
- demo2_fps = gr.Slider(
390
  minimum=2,
391
  maximum=10,
392
- value=4,
393
  step=1,
394
- label="FPS",
395
- info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
396
  interactive=True,
397
  )
398
- gr.Button("Generate Video").click(
399
- fn=single_image_to_long_video,
400
- inputs=[
401
- demo2_selected_index,
402
- demo2_guidance_scale,
403
- demo2_fps,
404
  ],
405
- outputs=demo2_video,
406
  )
407
 
408
- with gr.Tab("Single Image → Extremely Long Video", id="task-3"):
409
- gr.Markdown(
410
- """
411
- ## Demo 3: Single Image → Extremely Long Video
412
- > #### _TODO._
413
- """
414
- )
415
 
416
  if __name__ == "__main__":
417
  demo.launch()
 
6
  import numpy as np
7
  import torch
8
  from torchvision.datasets.utils import download_and_extract_archive
9
+ from einops import repeat
10
  from omegaconf import OmegaConf
11
  from algorithms.dfot import DFoTVideoPose
12
+ from history_guidance import HistoryGuidance
13
  from utils.ckpt_utils import download_pretrained
 
14
  from datasets.video.utils.io import read_video
 
15
  from export import export_to_video, export_to_gif, export_images_to_gif
16
+ from camera_pose import extend_poses, CameraPose
17
+ from scipy.spatial.transform import Rotation, Slerp
18
 
19
  DATASET_URL = "https://huggingface.co/kiwhansong/DFoT/resolve/main/datasets/RealEstate10K_Tiny.tar.gz"
20
  DATASET_DIR = Path("data/real-estate-10k-tiny")
21
+ LONG_LENGTH = 10 # seconds
22
+ NAVIGATION_FPS = 3
23
 
24
  if not DATASET_DIR.exists():
25
  DATASET_DIR.mkdir(parents=True)
 
70
 
71
  def prepare_long_gt_video(idx: int):
72
  video = video_list[idx]
73
+ indices = torch.linspace(0, video.size(0) - 1, 200, dtype=torch.long)
74
+ return export_to_video(video[indices], fps=200 // LONG_LENGTH)
75
 
76
 
77
  def prepare_short_gt_video(idx: int):
 
105
  xs = video[indices].unsqueeze(0).to("cuda")
106
  conditions = poses[indices].unsqueeze(0).to("cuda")
107
  dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
108
+ dfot.cfg.tasks.prediction.keyframe_density = 12 / (fps * LONG_LENGTH)
109
  # dfot.cfg.tasks.interpolation.history_guidance.guidance_scale = guidance_scale
110
  gen_video = dfot._unnormalize_x(
111
  dfot._predict_videos(
 
152
  return video_to_gif_and_images([image for image in gen_video], list(range(8)))
153
 
154
 
155
+ class CustomProgressBar:
156
+ def __init__(self, pbar):
157
+ self.pbar = pbar
158
+
159
+ def set_postfix(self, **kwargs):
160
+ pass
161
+
162
+ def __getattr__(self, attr):
163
+ return getattr(self.pbar, attr)
164
+
165
+
166
+ @torch.autocast("cuda")
167
+ @torch.no_grad()
168
+ def navigate_video(
169
+ video: torch.Tensor,
170
+ poses: torch.Tensor,
171
+ x_angle: float,
172
+ y_angle: float,
173
+ distance: float,
174
+ ):
175
+ n_context_frames = min(len(video), 4)
176
+ n_prediction_frames = 8 - n_context_frames
177
+ pbar = CustomProgressBar(
178
+ gr.Progress(track_tqdm=True).tqdm(
179
+ iterable=None,
180
+ desc=f"Predicting next {n_prediction_frames} frames",
181
+ total=dfot.sampling_timesteps,
182
+ )
183
+ )
184
+ xs = dfot._normalize_x(video.clone().unsqueeze(0).to("cuda"))
185
+ conditions = poses.clone().unsqueeze(0).to("cuda")
186
+ conditions = extend_poses(
187
+ conditions,
188
+ n=n_prediction_frames,
189
+ x_angle=x_angle,
190
+ y_angle=y_angle,
191
+ distance=distance,
192
+ )
193
+ context_mask = (
194
+ torch.cat(
195
+ [
196
+ torch.ones(1, n_context_frames) * (1 if n_context_frames == 1 else 2),
197
+ torch.zeros(1, n_prediction_frames),
198
+ ],
199
+ dim=-1,
200
+ )
201
+ .long()
202
+ .to("cuda")
203
+ )
204
+ next_video = (
205
+ dfot._unnormalize_x(
206
+ dfot._sample_sequence(
207
+ batch_size=1,
208
+ context=torch.cat(
209
+ [
210
+ xs[:, -n_context_frames:],
211
+ torch.zeros(
212
+ 1,
213
+ n_prediction_frames,
214
+ *xs.shape[2:],
215
+ device=xs.device,
216
+ dtype=xs.dtype,
217
+ ),
218
+ ],
219
+ dim=1,
220
+ ),
221
+ context_mask=context_mask,
222
+ conditions=conditions[:, -8:],
223
+ history_guidance=HistoryGuidance.smart(
224
+ x_angle=x_angle,
225
+ y_angle=y_angle,
226
+ distance=distance,
227
+ visualize=False,
228
+ ),
229
+ pbar=pbar,
230
+ )[0]
231
+ )[0][n_context_frames:]
232
+ .detach()
233
+ .cpu()
234
+ )
235
+ gen_video = torch.cat([video, next_video], dim=0)
236
+ poses = conditions[0]
237
+
238
+ images = (gen_video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
239
+
240
+ return (
241
+ gen_video,
242
+ poses,
243
+ images[-1],
244
+ export_to_video(gen_video, fps=NAVIGATION_FPS),
245
+ [(image, f"t={i}") for i, image in enumerate(images)],
246
+ )
247
+
248
+ def undo_navigation(
249
+ video: torch.Tensor,
250
+ poses: torch.Tensor,
251
+ ):
252
+ if len(video) >= 8:
253
+ video = video[:-4]
254
+ poses = poses[:-4]
255
+ else:
256
+ gr.Warning("You have no moves left to undo!")
257
+ images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
258
+ return (
259
+ video,
260
+ poses,
261
+ images[-1],
262
+ export_to_video(video, fps=NAVIGATION_FPS),
263
+ [(image, f"t={i}") for i, image in enumerate(images)],
264
+ )
265
+
266
+ def _interpolate_conditions(conditions, indices):
267
+ """
268
+ Interpolate conditions to fill out missing frames
269
+
270
+ Args:
271
+ conditions (Tensor): conditions (B, T, C)
272
+ indices (Tensor): indices of keyframes (T')
273
+ """
274
+ assert indices[0].item() == 0
275
+ assert indices[-1].item() == conditions.shape[1] - 1
276
+
277
+ indices = indices.cpu().numpy()
278
+ batch_size, n_tokens, _ = conditions.shape
279
+ t = np.linspace(0, n_tokens - 1, n_tokens)
280
+
281
+ key_conditions = conditions[:, indices]
282
+ poses = CameraPose.from_vectors(key_conditions)
283
+ extrinsics = poses.extrinsics().cpu().numpy()
284
+ ps = extrinsics[..., :3, 3]
285
+ rs = extrinsics[..., :3, :3].reshape(batch_size, -1, 3, 3)
286
+
287
+ interp_extrinsics = np.zeros((batch_size, n_tokens, 3, 4))
288
+ for i in range(batch_size):
289
+ slerp = Slerp(indices, Rotation.from_matrix(rs[i]))
290
+ interp_extrinsics[i, :, :3, :3] = slerp(t).as_matrix()
291
+ for j in range(3):
292
+ interp_extrinsics[i, :, j, 3] = np.interp(t, indices, ps[i, :, j])
293
+ interp_extrinsics = torch.from_numpy(interp_extrinsics.astype(np.float32))
294
+ interp_extrinsics = interp_extrinsics.to(conditions.device).flatten(2)
295
+ conditions = repeat(key_conditions[:, 0, :4], "b c -> b t c", t=n_tokens)
296
+ conditions = torch.cat([conditions.clone(), interp_extrinsics], dim=-1)
297
+
298
+ return conditions
299
+
300
+ @spaces.GPU(duration=300)
301
+ @torch.autocast("cuda")
302
+ @torch.no_grad()
303
+ def _interpolate_between(
304
+ xs: torch.Tensor,
305
+ conditions: torch.Tensor,
306
+ interpolation_factor: int,
307
+ progress=gr.Progress(track_tqdm=True),
308
+ ):
309
+ l = xs.shape[1]
310
+ final_l = (l - 1) * interpolation_factor + 1
311
+ x_shape = xs.shape[2:]
312
+ context = torch.zeros(
313
+ (
314
+ 1,
315
+ final_l,
316
+ *x_shape,
317
+ ),
318
+ device=xs.device,
319
+ dtype=xs.dtype,
320
+ )
321
+ long_conditions = torch.zeros(
322
+ (1, final_l, *conditions.shape[2:]),
323
+ device=conditions.device,
324
+ dtype=conditions.dtype,
325
+ )
326
+ context_mask = torch.zeros(
327
+ (1, final_l),
328
+ device=xs.device,
329
+ dtype=torch.bool,
330
+ )
331
+ context_indices = torch.arange(
332
+ 0, final_l, interpolation_factor, device=conditions.device
333
+ )
334
+ context[:, context_indices] = xs
335
+ long_conditions[:, context_indices] = conditions
336
+ context_mask[:, ::interpolation_factor] = True
337
+ long_conditions = _interpolate_conditions(
338
+ long_conditions,
339
+ context_indices,
340
+ )
341
+
342
+ xs = dfot._interpolate_videos(
343
+ context,
344
+ context_mask,
345
+ conditions=long_conditions,
346
+ )
347
+ return xs, long_conditions
348
+
349
+ def smooth_navigation(
350
+ video: torch.Tensor,
351
+ poses: torch.Tensor,
352
+ interpolation_factor: int,
353
+ progress=gr.Progress(track_tqdm=True),
354
+ ):
355
+ if len(video) < 8:
356
+ gr.Warning("Navigate first before applying temporal super-resolution!")
357
+ else:
358
+ video, poses = _interpolate_between(
359
+ dfot._normalize_x(video.clone().unsqueeze(0).to("cuda")),
360
+ poses.clone().unsqueeze(0).to("cuda"),
361
+ interpolation_factor,
362
+ )
363
+ video = dfot._unnormalize_x(video)[0].detach().cpu()
364
+ poses = poses[0]
365
+ images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
366
+ return (
367
+ video,
368
+ poses,
369
+ images[-1],
370
+ export_to_video(video, fps=NAVIGATION_FPS * interpolation_factor),
371
+ [(image, f"t={i}") for i, image in enumerate(images)],
372
+ )
373
+
374
+
375
+
376
+
377
  # Create the Gradio Blocks
378
  with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
379
  gr.HTML(
 
383
  font-size: 16px !important;
384
  font-weight: bold;
385
  }
386
+ #header-button .button-icon {
387
+ margin-right: 8px;
388
+ }
389
+ #basic-controls {
390
+ column-gap: 0px;
391
+ }
392
+ #basic-controls button {
393
+ border: 1px solid #e4e4e7;
394
+ }
395
+ #basic-controls-tab {
396
+ padding: 0px;
397
+ }
398
+ #advanced-controls-tab {
399
+ padding: 0px;
400
+ }
401
  </style>
402
  """
403
  )
 
407
  "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
408
  )
409
  with gr.Row():
410
  gr.Button(
411
+ value="Website",
412
+ link="https://boyuan.space/history-guidance",
413
+ icon="https://simpleicons.org/icons/googlechrome.svg",
414
+ elem_id="header-button",
415
+ )
416
+ gr.Button(
417
+ value="Paper",
418
+ link="https://arxiv.org/abs/2502.06764",
419
+ icon="https://simpleicons.org/icons/arxiv.svg",
420
+ elem_id="header-button",
421
+ )
422
+ gr.Button(
423
+ value="Code",
424
  link="https://github.com/kwsong0113/diffusion-forcing-transformer",
425
+ icon="https://simpleicons.org/icons/github.svg",
426
+ elem_id="header-button",
427
  )
428
  gr.Button(
429
+ value="Pretrained Models",
430
+ link="https://huggingface.co/kiwhansong/DFoT",
431
+ icon="https://simpleicons.org/icons/huggingface.svg",
432
+ elem_id="header-button",
433
  )
434
 
435
  with gr.Accordion("Troubleshooting: Not Working or Too Slow?", open=False):
 
440
  """
441
  )
442
 
 
443
  with gr.Tab("Any # of Images → Short Video", id="task-1"):
444
  gr.Markdown(
445
  """
 
477
  def update_selection(selection: gr.SelectData):
478
  return selection.index
479
 
480
+ demo1_scene_select_button = gr.Button("Select Scene", variant="primary")
481
 
482
  @demo1_scene_select_button.click(
483
  inputs=demo1_selected_scene_index, outputs=demo1_stage
 
509
  choices=[(f"t={i}", i) for i in range(8)],
510
  value=[],
511
  )
512
+ demo1_image_select_button = gr.Button("Select Input Images", variant="primary")
513
 
514
  @demo1_image_select_button.click(
515
  inputs=[demo1_selector],
 
556
  info="Without history guidance: 1.0; Recommended: 4.0",
557
  interactive=True,
558
  )
559
+ gr.Button("Generate Video", variant="primary").click(
560
  fn=any_images_to_short_video,
561
  inputs=[
562
  demo1_selected_scene_index,
 
568
 
569
  with gr.Tab("Single Image → Long Video", id="task-2"):
570
  gr.Markdown(
571
+ f"""
572
+ ## Demo 2: Single Image → Long {LONG_LENGTH}-second Video
573
+ > #### _Diffusion Forcing Transformer, with History Guidance, generates long videos via sliding window rollouts and temporal super-resolution._
574
  """
575
  )
576
 
 
596
  def update_selection(selection: gr.SelectData):
597
  return selection.index
598
 
599
+ demo2_select_button = gr.Button("Select Input Image", variant="primary")
600
 
601
  @demo2_select_button.click(
602
  inputs=demo2_selected_index, outputs=demo2_stage
 
621
  label="Ground Truth Video",
622
  width=256,
623
  height=256,
624
+ autoplay=True,
625
+ loop=True,
626
  )
627
  demo2_video = gr.Video(
628
+ label="Generated Video",
629
+ width=256,
630
+ height=256,
631
+ autoplay=True,
632
+ loop=True,
633
+ show_share_button=True,
634
+ show_download_button=True,
635
  )
636
 
637
+ with gr.Sidebar():
638
+ gr.Markdown("### Sampling Parameters")
639
 
640
+ demo2_guidance_scale = gr.Slider(
641
+ minimum=1,
642
+ maximum=6,
643
+ value=4,
644
+ step=0.5,
645
+ label="History Guidance Scale",
646
+ info="Without history guidance: 1.0; Recommended: 4.0",
647
+ interactive=True,
648
+ )
649
+ demo2_fps = gr.Slider(
650
+ minimum=4,
651
+ maximum=20,
652
+ value=8,
653
+ step=1,
654
+ label="FPS",
655
+ info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
656
+ interactive=True,
657
+ )
658
+ gr.Button("Generate Video", variant="primary").click(
659
+ fn=single_image_to_long_video,
660
+ inputs=[
661
+ demo2_selected_index,
662
+ demo2_guidance_scale,
663
+ demo2_fps,
664
+ ],
665
+ outputs=demo2_video,
666
+ )
667
+
668
+ with gr.Tab("Single Image → Endless Video Navigation", id="task-3"):
669
+ gr.Markdown(
670
+ """
671
+ ## Demo 3: Single Image → Extremely Long Video _(Navigate with Your Camera Movements!)_
672
+ > #### _History Guidance significantly improves quality and temporal consistency, enabling stable rollouts for extremely long videos._
673
+ """
674
+ )
675
+
676
+ demo3_stage = gr.State(value="Selection")
677
+ demo3_selected_index = gr.State(value=None)
678
+ demo3_current_video = gr.State(value=None)
679
+ demo3_current_poses = gr.State(value=None)
680
+
681
+ @gr.render(inputs=[demo3_stage, demo3_selected_index])
682
+ def render_stage(s, idx):
683
+ match s:
684
+ case "Selection":
685
+ with gr.Group():
686
+ demo3_image_gallery = gr.Gallery(
687
+ height=300,
688
+ value=first_frame_list,
689
+ label="Select an Image to Start Navigation",
690
+ columns=[8],
691
+ selected_index=idx,
692
+ )
693
+
694
+ @demo3_image_gallery.select(
695
+ inputs=None, outputs=demo3_selected_index
696
+ )
697
+ def update_selection(selection: gr.SelectData):
698
+ return selection.index
699
+
700
+ demo3_select_button = gr.Button("Select Input Image", variant="primary")
701
+
702
+ @demo3_select_button.click(
703
+ inputs=demo3_selected_index,
704
+ outputs=[
705
+ demo3_stage,
706
+ demo3_current_video,
707
+ demo3_current_poses,
708
+ ],
709
+ )
710
+ def move_to_generation(idx: int):
711
+ if idx is None:
712
+ gr.Warning("Image not selected!")
713
+ return "Selection", None, None
714
+ else:
715
+ return (
716
+ "Generation",
717
+ video_list[idx][:1],
718
+ poses_list[idx][:1],
719
+ )
720
+
721
+ case "Generation":
722
+ with gr.Row():
723
+ demo3_current_view = gr.Image(
724
+ value=first_frame_list[idx],
725
+ label="Current View",
726
+ width=256,
727
+ height=256,
728
+ )
729
+ demo3_video = gr.Video(
730
+ label="Generated Video",
731
+ width=256,
732
+ height=256,
733
+ autoplay=True,
734
+ loop=True,
735
+ show_share_button=True,
736
+ show_download_button=True,
737
+ )
738
+
739
+ demo3_generated_gallery = gr.Gallery(
740
+ value=[],
741
+ label="Generated Frames",
742
+ columns=[8],
743
+ )
744
+
745
+ with gr.Sidebar():
746
+ gr.Markdown(
747
+ """
748
+ ### Let's Navigate!
749
+ **The model will predict the next few frames based on your camera movements. Repeat the process to navigate through the scene.** The most suitable history guidance scheme is selected automatically for each move.
750
+ """
751
+ )
752
+ with gr.Tab("Basic", elem_id="basic-controls-tab"):
753
+ with gr.Group():
754
+ gr.Markdown("_**Select a direction to move:**_")
755
+ with gr.Row(elem_id="basic-controls"):
756
+ gr.Button("↰-60°\nTurn", size="sm", min_width=0, variant="primary").click(
757
+ fn=partial(
758
+ navigate_video,
759
+ x_angle=0,
760
+ y_angle=-60,
761
+ distance=0,
762
+ ),
763
+ inputs=[demo3_current_video, demo3_current_poses],
764
+ outputs=[
765
+ demo3_current_video,
766
+ demo3_current_poses,
767
+ demo3_current_view,
768
+ demo3_video,
769
+ demo3_generated_gallery,
770
+ ],
771
+ )
772
+
773
+ gr.Button("↖-30°\nVeer", size="sm", min_width=0, variant="primary").click(
774
+ fn=partial(
775
+ navigate_video,
776
+ x_angle=0,
777
+ y_angle=-30,
778
+ distance=50,
779
+ ),
780
+ inputs=[demo3_current_video, demo3_current_poses],
781
+ outputs=[
782
+ demo3_current_video,
783
+ demo3_current_poses,
784
+ demo3_current_view,
785
+ demo3_video,
786
+ demo3_generated_gallery,
787
+ ],
788
+ )
789
+
790
+ gr.Button("↑0°\nAhead", size="sm", min_width=0, variant="primary").click(
791
+ fn=partial(
792
+ navigate_video,
793
+ x_angle=0,
794
+ y_angle=0,
795
+ distance=100,
796
+ ),
797
+ inputs=[demo3_current_video, demo3_current_poses],
798
+ outputs=[
799
+ demo3_current_video,
800
+ demo3_current_poses,
801
+ demo3_current_view,
802
+ demo3_video,
803
+ demo3_generated_gallery,
804
+ ],
805
+ )
806
+ gr.Button("↗30°\nVeer", size="sm", min_width=0, variant="primary").click(
807
+ fn=partial(
808
+ navigate_video,
809
+ x_angle=0,
810
+ y_angle=30,
811
+ distance=50,
812
+ ),
813
+ inputs=[demo3_current_video, demo3_current_poses],
814
+ outputs=[
815
+ demo3_current_video,
816
+ demo3_current_poses,
817
+ demo3_current_view,
818
+ demo3_video,
819
+ demo3_generated_gallery,
820
+ ],
821
+ )
822
+ gr.Button("↱\n60° Turn", size="sm", min_width=0, variant="primary").click(
823
+ fn=partial(
824
+ navigate_video,
825
+ x_angle=0,
826
+ y_angle=60,
827
+ distance=0,
828
+ ),
829
+ inputs=[demo3_current_video, demo3_current_poses],
830
+ outputs=[
831
+ demo3_current_video,
832
+ demo3_current_poses,
833
+ demo3_current_view,
834
+ demo3_video,
835
+ demo3_generated_gallery,
836
+ ],
837
+ )
838
+ with gr.Tab("Advanced", elem_id="advanced-controls-tab"):
839
+ with gr.Group():
840
+ gr.Markdown("_**Select angles and distance:**_")
841
+
842
+ demo3_y_angle = gr.Slider(
843
+ minimum=-90,
844
+ maximum=90,
845
+ value=0,
846
+ step=10,
847
+ label="Horizontal Angle",
848
+ interactive=True,
849
+ )
850
+ demo3_x_angle = gr.Slider(
851
+ minimum=-40,
852
+ maximum=40,
853
+ value=0,
854
+ step=10,
855
+ label="Vertical Angle",
856
+ interactive=True,
857
+ )
858
+ demo3_distance = gr.Slider(
859
+ minimum=0,
860
+ maximum=200,
861
+ value=100,
862
+ step=10,
863
+ label="Distance",
864
+ interactive=True,
865
+ )
866
+
867
+ gr.Button("Generate Next Move", variant="primary").click(
868
+ fn=partial(
869
+ navigate_video,
870
+ ),
871
+ inputs=[demo3_current_video, demo3_current_poses, demo3_x_angle, demo3_y_angle, demo3_distance],
872
+ outputs=[
873
+ demo3_current_video,
874
+ demo3_current_poses,
875
+ demo3_current_view,
876
+ demo3_video,
877
+ demo3_generated_gallery,
878
+ ],
879
+ )
880
+ with gr.Group():
881
+ gr.Markdown("_You can always undo your last move:_")
882
+ gr.Button("Undo Last Move", variant="huggingface").click(
883
+ fn=undo_navigation,
884
+ inputs=[demo3_current_video, demo3_current_poses],
885
+ outputs=[
886
+ demo3_current_video,
887
+ demo3_current_poses,
888
+ demo3_current_view,
889
+ demo3_video,
890
+ demo3_generated_gallery,
891
+ ],
892
  )
893
+ with gr.Group():
894
+ gr.Markdown("_At the end, apply temporal super-resolution to obtain a smoother video:_")
895
+ demo3_interpolation_factor=gr.Slider(
896
  minimum=2,
897
  maximum=10,
898
+ value=2,
899
  step=1,
900
+ label="Interpolation Factor",
 
901
  interactive=True,
902
  )
903
+ gr.Button("Smooth Out Video", variant="huggingface").click(
904
+ fn=smooth_navigation,
905
+ inputs=[demo3_current_video, demo3_current_poses, demo3_interpolation_factor],
906
+ outputs=[
907
+ demo3_current_video,
908
+ demo3_current_poses,
909
+ demo3_current_view,
910
+ demo3_video,
911
+ demo3_generated_gallery,
912
  ],
 
913
  )
914
 
915
 
916
  if __name__ == "__main__":
917
  demo.launch()
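
Note on the changed keyframe_density line: with LONG_LENGTH = 10 and the new FPS slider (4-20, default 8), setting keyframe_density = 12 / (fps * LONG_LENGTH) keeps the keyframe budget of the sliding-window prediction stage at roughly 12 frames, with temporal super-resolution filling in the rest. A minimal sketch of that arithmetic, assuming keyframe_density is interpreted as the fraction of output frames sampled as keyframes (the exact semantics live inside DFoTVideoPose):

    # Sketch of the keyframe budget implied by `keyframe_density = 12 / (fps * LONG_LENGTH)`.
    LONG_LENGTH = 10  # seconds, as defined at the top of app.py

    def keyframe_budget(fps: int) -> float:
        total_frames = fps * LONG_LENGTH        # frames in the generated long video
        density = 12 / (fps * LONG_LENGTH)      # value written into dfot.cfg
        return density * total_frames           # expected number of keyframes

    for fps in (4, 8, 20):
        print(fps, keyframe_budget(fps))        # always 12.0, independent of FPS
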
camera_pose.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from utils.geometry_utils import CameraPose
3
+ from einops import rearrange, repeat
4
+ import math
5
+ import roma
6
+
7
+ class ControllableCameraPose(CameraPose):
8
+ def to_vectors(self) -> torch.Tensor:
9
+ """
10
+ Returns the raw camera poses.
11
+ Returns:
12
+ torch.Tensor: The raw camera poses. Shape (B, T, 4 + 12).
13
+ """
14
+ RT = torch.cat([self._R, rearrange(self._T, "b t i -> b t i 1")], dim=-1)
15
+ return torch.cat([self._K, rearrange(RT, "b t i j -> b t (i j)")], dim=-1)
16
+
17
+ def extend(
18
+ self,
19
+ num_frames: int,
20
+ x_angle: float = 0.0,
21
+ y_angle: float = 0.0,
22
+ distance: float = 100.0,
23
+ ) -> None:
24
+ """
25
+ Extends the camera poses.
26
+ 0 degrees is taken to be the viewing direction of the last camera pose.
27
+ Smoothly moves and rotates the camera poses in the direction of the given angles (clockwise) in a 2D plane.
28
+ Args:
29
+ num_frames (int): The number of frames to extend.
30
+ x_angle (float): The vertical rotation angle, in degrees.
31
+ y_angle (float): The horizontal rotation angle, in degrees.
32
+ """
33
+ MOVING_SCALE = 0.5 * distance / 100
34
+ self._normalize_by(self._R[:, -1], self._T[:, -1])
35
+
36
+ # first compute relative poses for the final n + num_frames th frame
37
+
38
+ # compute the rotation matrix for the given angle
39
+ R_final = roma.euler_to_rotmat(
40
+ convention="xyz",
41
+ angles=torch.tensor(
42
+ [-x_angle, -y_angle, 0], device=self._R.device, dtype=torch.float32
43
+ ),
44
+ degrees=True,
45
+ dtype=torch.float32,
46
+ device=self._R.device,
47
+ ).unsqueeze(0)
48
+
49
+ # compute the translation vector for the given angle
50
+ T_final = torch.tensor(
51
+ [
52
+ -MOVING_SCALE * num_frames * math.sin(math.radians(y_angle)),
53
+ MOVING_SCALE * num_frames * math.sin(math.radians(x_angle)),
54
+ -MOVING_SCALE * num_frames * math.cos(math.radians(y_angle)),
55
+ ],
56
+ device=self._T.device,
57
+ dtype=self._T.dtype,
58
+ ).unsqueeze(0)
59
+
60
+ R = torch.cat(
61
+ [self._R, repeat(R_final, "b i j -> b t i j", t=num_frames).clone()], dim=1
62
+ )
63
+ T = torch.cat(
64
+ [self._T, repeat(T_final, "b i -> b t i", t=num_frames).clone()], dim=1
65
+ )
66
+ K = torch.cat(
67
+ [self._K, repeat(self._K[:, -1], "b i -> b t i", t=num_frames).clone()],
68
+ dim=1,
69
+ )
70
+ self._R = R
71
+ self._T = T
72
+ self._K = K
73
+ # interpolate all frames between the last frame and the final frame
74
+ self.replace_with_interpolation(
75
+ torch.cat(
76
+ [
77
+ torch.zeros_like(self._T[:, :-num_frames, 0]),
78
+ torch.ones_like(self._T[:, -num_frames:-1, 0]),
79
+ torch.zeros_like(self._T[:, -1:, 0]),
80
+ ],
81
+ dim=-1,
82
+ ).bool()
83
+ )
84
+
85
+ def extend_poses(
86
+ conditions: torch.Tensor,
87
+ n: int,
88
+ x_angle: float = 0.0,
89
+ y_angle: float = 0.0,
90
+ distance: float = 0.0,
91
+ ) -> torch.Tensor:
92
+ poses = ControllableCameraPose.from_vectors(conditions)
93
+ poses.extend(n, x_angle, y_angle, distance)
94
+ return poses.to_vectors()
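
Note: a hypothetical usage sketch of extend_poses, mirroring how navigate_video calls it in app.py. The pose-vector layout (4 intrinsics followed by a flattened 3x4 extrinsic, 16 values per frame) follows to_vectors() above; the intrinsic values and identity extrinsics below are placeholders, not real dataset poses.

    import torch
    from camera_pose import extend_poses

    # One camera pose: [fx, fy, cx, cy] intrinsics (placeholder values)
    # followed by a flattened [I | 0] extrinsic, repeated for 4 context frames.
    K = torch.tensor([1.0, 1.0, 0.5, 0.5])
    RT = torch.eye(3, 4).flatten()
    poses = torch.cat([K, RT]).repeat(1, 4, 1)   # shape (1, 4, 16)

    # Extend by 4 frames: turn 30 degrees to the right while moving forward,
    # within the ranges exposed by the demo's sliders (distance 0-200).
    extended = extend_poses(poses, n=4, x_angle=0.0, y_angle=30.0, distance=50.0)
    print(extended.shape)                        # (1, 8, 16)
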
history_guidance.py ADDED
@@ -0,0 +1,24 @@
1
+ from algorithms.dfot.history_guidance import HistoryGuidance as _HistoryGuidance
2
+
3
+ class HistoryGuidance(_HistoryGuidance):
4
+ @classmethod
5
+ def smart(
6
+ cls,
7
+ x_angle: float,
8
+ y_angle: float,
9
+ distance: float,
10
+ visualize: bool = False,
11
+ ):
12
+ if abs(x_angle) < 30 and abs(y_angle) < 30 and distance < 150:
13
+ return cls.stabilized_fractional(
14
+ guidance_scale=4.0,
15
+ stabilization_level=0.02,
16
+ freq_scale=0.4,
17
+ visualize=visualize,
18
+ )
19
+ else:
20
+ return cls.stabilized_vanilla(
21
+ guidance_scale=4.0,
22
+ stabilization_level=0.02,
23
+ visualize=visualize,
24
+ )
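
Note: smart() simply picks between two existing guidance schemes based on the magnitude of the requested camera move, using the thresholds hard-coded above. A hypothetical usage sketch, mirroring the calls navigate_video makes for the demo's "Ahead" and "60° Turn" buttons:

    from history_guidance import HistoryGuidance

    # Small move (straight ahead, moderate distance) -> stabilized fractional guidance
    hg_ahead = HistoryGuidance.smart(x_angle=0, y_angle=0, distance=100)

    # Aggressive move (60-degree turn in place) -> stabilized vanilla guidance
    hg_turn = HistoryGuidance.smart(x_angle=0, y_angle=60, distance=0)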