Spaces: Running on Zero
Commit · 5359939
1 Parent(s): eb1feee
finish demo
- app.py +549 -49
- camera_pose.py +94 -0
- history_guidance.py +24 -0
app.py
CHANGED
@@ -6,19 +6,20 @@ import gradio as gr
 import numpy as np
 import torch
 from torchvision.datasets.utils import download_and_extract_archive
-from
+from einops import repeat
 from omegaconf import OmegaConf
 from algorithms.dfot import DFoTVideoPose
-from
+from history_guidance import HistoryGuidance
 from utils.ckpt_utils import download_pretrained
-from utils.huggingface_utils import download_from_hf
 from datasets.video.utils.io import read_video
-from datasets.video import RealEstate10KAdvancedVideoDataset
 from export import export_to_video, export_to_gif, export_images_to_gif
+from camera_pose import extend_poses, CameraPose
+from scipy.spatial.transform import Rotation, Slerp
 
 DATASET_URL = "https://huggingface.co/kiwhansong/DFoT/resolve/main/datasets/RealEstate10K_Tiny.tar.gz"
 DATASET_DIR = Path("data/real-estate-10k-tiny")
-LONG_LENGTH =
+LONG_LENGTH = 10 # seconds
+NAVIGATION_FPS = 3
 
 if not DATASET_DIR.exists():
     DATASET_DIR.mkdir(parents=True)
@@ -69,8 +70,8 @@ dfot.to("cuda")
 
 def prepare_long_gt_video(idx: int):
     video = video_list[idx]
-    indices = torch.linspace(0, video.size(0) - 1,
-    return export_to_video(video[indices], fps=
+    indices = torch.linspace(0, video.size(0) - 1, 200, dtype=torch.long)
+    return export_to_video(video[indices], fps=200 // LONG_LENGTH)
 
 
 def prepare_short_gt_video(idx: int):
@@ -104,7 +105,7 @@ def single_image_to_long_video(
     xs = video[indices].unsqueeze(0).to("cuda")
     conditions = poses[indices].unsqueeze(0).to("cuda")
     dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
-    dfot.cfg.tasks.prediction.keyframe_density =
+    dfot.cfg.tasks.prediction.keyframe_density = 12 / (fps * LONG_LENGTH)
     # dfot.cfg.tasks.interpolation.history_guidance.guidance_scale = guidance_scale
     gen_video = dfot._unnormalize_x(
         dfot._predict_videos(
@@ -151,6 +152,228 @@ def any_images_to_short_video(
     return video_to_gif_and_images([image for image in gen_video], list(range(8)))
 
 
+class CustomProgressBar:
+    def __init__(self, pbar):
+        self.pbar = pbar
+
+    def set_postfix(self, **kwargs):
+        pass
+
+    def __getattr__(self, attr):
+        return getattr(self.pbar, attr)
+
+
+@torch.autocast("cuda")
+@torch.no_grad()
+def navigate_video(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+    x_angle: float,
+    y_angle: float,
+    distance: float,
+):
+    n_context_frames = min(len(video), 4)
+    n_prediction_frames = 8 - n_context_frames
+    pbar = CustomProgressBar(
+        gr.Progress(track_tqdm=True).tqdm(
+            iterable=None,
+            desc=f"Predicting next {n_prediction_frames} frames",
+            total=dfot.sampling_timesteps,
+        )
+    )
+    xs = dfot._normalize_x(video.clone().unsqueeze(0).to("cuda"))
+    conditions = poses.clone().unsqueeze(0).to("cuda")
+    conditions = extend_poses(
+        conditions,
+        n=n_prediction_frames,
+        x_angle=x_angle,
+        y_angle=y_angle,
+        distance=distance,
+    )
+    context_mask = (
+        torch.cat(
+            [
+                torch.ones(1, n_context_frames) * (1 if n_context_frames == 1 else 2),
+                torch.zeros(1, n_prediction_frames),
+            ],
+            dim=-1,
+        )
+        .long()
+        .to("cuda")
+    )
+    next_video = (
+        dfot._unnormalize_x(
+            dfot._sample_sequence(
+                batch_size=1,
+                context=torch.cat(
+                    [
+                        xs[:, -n_context_frames:],
+                        torch.zeros(
+                            1,
+                            n_prediction_frames,
+                            *xs.shape[2:],
+                            device=xs.device,
+                            dtype=xs.dtype,
+                        ),
+                    ],
+                    dim=1,
+                ),
+                context_mask=context_mask,
+                conditions=conditions[:, -8:],
+                history_guidance=HistoryGuidance.smart(
+                    x_angle=x_angle,
+                    y_angle=y_angle,
+                    distance=distance,
+                    visualize=False,
+                ),
+                pbar=pbar,
+            )[0]
+        )[0][n_context_frames:]
+        .detach()
+        .cpu()
+    )
+    gen_video = torch.cat([video, next_video], dim=0)
+    poses = conditions[0]
+
+    images = (gen_video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+
+    return (
+        gen_video,
+        poses,
+        images[-1],
+        export_to_video(gen_video, fps=NAVIGATION_FPS),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+def undo_navigation(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+):
+    if len(video) >= 8:
+        video = video[:-4]
+        poses = poses[:-4]
+    else:
+        gr.Warning("You have no moves left to undo!")
+    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+    return (
+        video,
+        poses,
+        images[-1],
+        export_to_video(video, fps=NAVIGATION_FPS),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+def _interpolate_conditions(conditions, indices):
+    """
+    Interpolate conditions to fill out missing frames
+
+    Args:
+        conditions (Tensor): conditions (B, T, C)
+        indices (Tensor): indices of keyframes (T')
+    """
+    assert indices[0].item() == 0
+    assert indices[-1].item() == conditions.shape[1] - 1
+
+    indices = indices.cpu().numpy()
+    batch_size, n_tokens, _ = conditions.shape
+    t = np.linspace(0, n_tokens - 1, n_tokens)
+
+    key_conditions = conditions[:, indices]
+    poses = CameraPose.from_vectors(key_conditions)
+    extrinsics = poses.extrinsics().cpu().numpy()
+    ps = extrinsics[..., :3, 3]
+    rs = extrinsics[..., :3, :3].reshape(batch_size, -1, 3, 3)
+
+    interp_extrinsics = np.zeros((batch_size, n_tokens, 3, 4))
+    for i in range(batch_size):
+        slerp = Slerp(indices, Rotation.from_matrix(rs[i]))
+        interp_extrinsics[i, :, :3, :3] = slerp(t).as_matrix()
+        for j in range(3):
+            interp_extrinsics[i, :, j, 3] = np.interp(t, indices, ps[i, :, j])
+    interp_extrinsics = torch.from_numpy(interp_extrinsics.astype(np.float32))
+    interp_extrinsics = interp_extrinsics.to(conditions.device).flatten(2)
+    conditions = repeat(key_conditions[:, 0, :4], "b c -> b t c", t=n_tokens)
+    conditions = torch.cat([conditions.clone(), interp_extrinsics], dim=-1)
+
+    return conditions
+
+@spaces.GPU(duration=300)
+@torch.autocast("cuda")
+@torch.no_grad()
+def _interpolate_between(
+    xs: torch.Tensor,
+    conditions: torch.Tensor,
+    interpolation_factor: int,
+    progress=gr.Progress(track_tqdm=True),
+):
+    l = xs.shape[1]
+    final_l = (l - 1) * interpolation_factor + 1
+    x_shape = xs.shape[2:]
+    context = torch.zeros(
+        (
+            1,
+            final_l,
+            *x_shape,
+        ),
+        device=xs.device,
+        dtype=xs.dtype,
+    )
+    long_conditions = torch.zeros(
+        (1, final_l, *conditions.shape[2:]),
+        device=conditions.device,
+        dtype=conditions.dtype,
+    )
+    context_mask = torch.zeros(
+        (1, final_l),
+        device=xs.device,
+        dtype=torch.bool,
+    )
+    context_indices = torch.arange(
+        0, final_l, interpolation_factor, device=conditions.device
+    )
+    context[:, context_indices] = xs
+    long_conditions[:, context_indices] = conditions
+    context_mask[:, ::interpolation_factor] = True
+    long_conditions = _interpolate_conditions(
+        long_conditions,
+        context_indices,
+    )
+
+    xs = dfot._interpolate_videos(
+        context,
+        context_mask,
+        conditions=long_conditions,
+    )
+    return xs, long_conditions
+
+def smooth_navigation(
+    video: torch.Tensor,
+    poses: torch.Tensor,
+    interpolation_factor: int,
+    progress=gr.Progress(track_tqdm=True),
+):
+    if len(video) < 8:
+        gr.Warning("Navigate first before applying temporal super-resolution!")
+    else:
+        video, poses = _interpolate_between(
+            dfot._normalize_x(video.clone().unsqueeze(0).to("cuda")),
+            poses.clone().unsqueeze(0).to("cuda"),
+            interpolation_factor,
+        )
+        video = dfot._unnormalize_x(video)[0].detach().cpu()
+        poses = poses[0]
+    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
+    return (
+        video,
+        poses,
+        images[-1],
+        export_to_video(video, fps=NAVIGATION_FPS * interpolation_factor),
+        [(image, f"t={i}") for i, image in enumerate(images)],
+    )
+
+
+
+
 # Create the Gradio Blocks
 with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
     gr.HTML(
@@ -160,6 +383,21 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
            font-size: 16px !important;
            font-weight: bold;
        }
+       #header-button .button-icon {
+           margin-right: 8px;
+       }
+       #basic-controls {
+           column-gap: 0px;
+       }
+       #basic-controls button {
+           border: 1px solid #e4e4e7;
+       }
+       #basic-controls-tab {
+           padding: 0px;
+       }
+       #advanced-controls-tab {
+           padding: 0px;
+       }
        </style>
        """
    )
@@ -169,14 +407,29 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
        "### Official Interactive Demo for [_History-guided Video Diffusion_](todo)"
    )
    with gr.Row():
-       gr.Button(value="🌐 Website", link="https://boyuan.space/history-guidance")
-       gr.Button(value="📄 Paper", link="https://arxiv.org/abs/2502.06764")
        gr.Button(
-           value="
+           value="Website",
+           link="https://boyuan.space/history-guidance",
+           icon="https://simpleicons.org/icons/googlechrome.svg",
+           elem_id="header-button",
+       )
+       gr.Button(
+           value="Paper",
+           link="https://arxiv.org/abs/2502.06764",
+           icon="https://simpleicons.org/icons/arxiv.svg",
+           elem_id="header-button",
+       )
+       gr.Button(
+           value="Code",
            link="https://github.com/kwsong0113/diffusion-forcing-transformer",
+           icon="https://simpleicons.org/icons/github.svg",
+           elem_id="header-button",
        )
        gr.Button(
-           value="
+           value="Pretrained Models",
+           link="https://huggingface.co/kiwhansong/DFoT",
+           icon="https://simpleicons.org/icons/huggingface.svg",
+           elem_id="header-button",
        )
 
    with gr.Accordion("Troubleshooting: Not Working or Too Slow?", open=False):
@@ -187,7 +440,6 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
            """
        )
 
-
    with gr.Tab("Any # of Images → Short Video", id="task-1"):
        gr.Markdown(
            """
@@ -225,7 +477,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
            def update_selection(selection: gr.SelectData):
                return selection.index
 
-           demo1_scene_select_button = gr.Button("Select Scene")
+           demo1_scene_select_button = gr.Button("Select Scene", variant="primary")
 
           @demo1_scene_select_button.click(
               inputs=demo1_selected_scene_index, outputs=demo1_stage
@@ -257,7 +509,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
               choices=[(f"t={i}", i) for i in range(8)],
               value=[],
           )
-          demo1_image_select_button = gr.Button("Select Input Images")
+          demo1_image_select_button = gr.Button("Select Input Images", variant="primary")
 
          @demo1_image_select_button.click(
             inputs=[demo1_selector],
@@ -304,7 +556,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
                info="Without history guidance: 1.0; Recommended: 4.0",
                interactive=True,
            )
-           gr.Button("Generate Video").click(
+           gr.Button("Generate Video", variant="primary").click(
                fn=any_images_to_short_video,
                inputs=[
                    demo1_selected_scene_index,
@@ -316,9 +568,9 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
 
    with gr.Tab("Single Image → Long Video", id="task-2"):
        gr.Markdown(
-           """
-           ## Demo 2: Single Image → Long
-           > #### _Diffusion Forcing Transformer, with History Guidance,
+           f"""
+           ## Demo 2: Single Image → Long {LONG_LENGTH}-second Video
+           > #### _Diffusion Forcing Transformer, with History Guidance, generates long videos via sliding window rollouts and temporal super-resolution._
            """
        )
 
@@ -344,7 +596,7 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
            def update_selection(selection: gr.SelectData):
                return selection.index
 
-           demo2_select_button = gr.Button("Select Input Image")
+           demo2_select_button = gr.Button("Select Input Image", variant="primary")
 
           @demo2_select_button.click(
               inputs=demo2_selected_index, outputs=demo2_stage
@@ -369,49 +621,297 @@ with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
                    label="Ground Truth Video",
                    width=256,
                    height=256,
+                   autoplay=True,
+                   loop=True,
                )
                demo2_video = gr.Video(
-                   label="Generated Video",
+                   label="Generated Video",
+                   width=256,
+                   height=256,
+                   autoplay=True,
+                   loop=True,
+                   show_share_button=True,
+                   show_download_button=True,
                )
 
-
-
+           with gr.Sidebar():
+               gr.Markdown("### Sampling Parameters")
 
-
-
-
-
-
-
-
-
+               demo2_guidance_scale = gr.Slider(
+                   minimum=1,
+                   maximum=6,
+                   value=4,
+                   step=0.5,
+                   label="History Guidance Scale",
+                   info="Without history guidance: 1.0; Recommended: 4.0",
+                   interactive=True,
+               )
+               demo2_fps = gr.Slider(
+                   minimum=4,
+                   maximum=20,
+                   value=8,
+                   step=1,
+                   label="FPS",
+                   info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
+                   interactive=True,
+               )
+               gr.Button("Generate Video", variant="primary").click(
+                   fn=single_image_to_long_video,
+                   inputs=[
+                       demo2_selected_index,
+                       demo2_guidance_scale,
+                       demo2_fps,
+                   ],
+                   outputs=demo2_video,
+               )
+
+   with gr.Tab("Single Image → Endless Video Navigation", id="task-3"):
+       gr.Markdown(
+           """
+           ## Demo 3: Single Image → Extremely Long Video _(Navigate with Your Camera Movements!)_
+           > #### _History Guidance significantly improves quality and temporal consistency, enabling stable rollouts for extremely long videos._
+           """
+       )
+
+       demo3_stage = gr.State(value="Selection")
+       demo3_selected_index = gr.State(value=None)
+       demo3_current_video = gr.State(value=None)
+       demo3_current_poses = gr.State(value=None)
+
+       @gr.render(inputs=[demo3_stage, demo3_selected_index])
+       def render_stage(s, idx):
+           match s:
+               case "Selection":
+                   with gr.Group():
+                       demo3_image_gallery = gr.Gallery(
+                           height=300,
+                           value=first_frame_list,
+                           label="Select an Image to Start Navigation",
+                           columns=[8],
+                           selected_index=idx,
+                       )
+
+                   @demo3_image_gallery.select(
+                       inputs=None, outputs=demo3_selected_index
+                   )
+                   def update_selection(selection: gr.SelectData):
+                       return selection.index
+
+                   demo3_select_button = gr.Button("Select Input Image", variant="primary")
+
+                   @demo3_select_button.click(
+                       inputs=demo3_selected_index,
+                       outputs=[
+                           demo3_stage,
+                           demo3_current_video,
+                           demo3_current_poses,
+                       ],
+                   )
+                   def move_to_generation(idx: int):
+                       if idx is None:
+                           gr.Warning("Image not selected!")
+                           return "Selection", None, None
+                       else:
+                           return (
+                               "Generation",
+                               video_list[idx][:1],
+                               poses_list[idx][:1],
+                           )
+
+               case "Generation":
+                   with gr.Row():
+                       demo3_current_view = gr.Image(
+                           value=first_frame_list[idx],
+                           label="Current View",
+                           width=256,
+                           height=256,
+                       )
+                       demo3_video = gr.Video(
+                           label="Generated Video",
+                           width=256,
+                           height=256,
+                           autoplay=True,
+                           loop=True,
+                           show_share_button=True,
+                           show_download_button=True,
+                       )
+
+                   demo3_generated_gallery = gr.Gallery(
+                       value=[],
+                       label="Generated Frames",
+                       columns=[8],
+                   )
+
+                   with gr.Sidebar():
+                       gr.Markdown(
+                           """
+                           ### Let's Navigate!
+                           **The model will predict the next few frames based on your camera movements. Repeat the process to navigate through the scene.** The most suitable history guidance scheme will be automatically selected based on your camera movements.
+                           """
+                       )
+                       with gr.Tab("Basic", elem_id="basic-controls-tab"):
+                           with gr.Group():
+                               gr.Markdown("_**Select a direction to move:**_")
+                               with gr.Row(elem_id="basic-controls"):
+                                   gr.Button("↰-60°\nTurn", size="sm", min_width=0, variant="primary").click(
+                                       fn=partial(
+                                           navigate_video,
+                                           x_angle=0,
+                                           y_angle=-60,
+                                           distance=0,
+                                       ),
+                                       inputs=[demo3_current_video, demo3_current_poses],
+                                       outputs=[
+                                           demo3_current_video,
+                                           demo3_current_poses,
+                                           demo3_current_view,
+                                           demo3_video,
+                                           demo3_generated_gallery,
+                                       ],
+                                   )
+
+                                   gr.Button("↖-30°\nVeer", size="sm", min_width=0, variant="primary").click(
+                                       fn=partial(
+                                           navigate_video,
+                                           x_angle=0,
+                                           y_angle=-30,
+                                           distance=50,
+                                       ),
+                                       inputs=[demo3_current_video, demo3_current_poses],
+                                       outputs=[
+                                           demo3_current_video,
+                                           demo3_current_poses,
+                                           demo3_current_view,
+                                           demo3_video,
+                                           demo3_generated_gallery,
+                                       ],
+                                   )
+
+                                   gr.Button("↑0°\nAhead", size="sm", min_width=0, variant="primary").click(
+                                       fn=partial(
+                                           navigate_video,
+                                           x_angle=0,
+                                           y_angle=0,
+                                           distance=100,
+                                       ),
+                                       inputs=[demo3_current_video, demo3_current_poses],
+                                       outputs=[
+                                           demo3_current_video,
+                                           demo3_current_poses,
+                                           demo3_current_view,
+                                           demo3_video,
+                                           demo3_generated_gallery,
+                                       ],
+                                   )
+                                   gr.Button("↗30°\nVeer", size="sm", min_width=0, variant="primary").click(
+                                       fn=partial(
+                                           navigate_video,
+                                           x_angle=0,
+                                           y_angle=30,
+                                           distance=50,
+                                       ),
+                                       inputs=[demo3_current_video, demo3_current_poses],
+                                       outputs=[
+                                           demo3_current_video,
+                                           demo3_current_poses,
+                                           demo3_current_view,
+                                           demo3_video,
+                                           demo3_generated_gallery,
+                                       ],
+                                   )
+                                   gr.Button("↱\n60° Turn", size="sm", min_width=0, variant="primary").click(
+                                       fn=partial(
+                                           navigate_video,
+                                           x_angle=0,
+                                           y_angle=60,
+                                           distance=0,
+                                       ),
+                                       inputs=[demo3_current_video, demo3_current_poses],
+                                       outputs=[
+                                           demo3_current_video,
+                                           demo3_current_poses,
+                                           demo3_current_view,
+                                           demo3_video,
+                                           demo3_generated_gallery,
+                                       ],
+                                   )
+                       with gr.Tab("Advanced", elem_id="advanced-controls-tab"):
+                           with gr.Group():
+                               gr.Markdown("_**Select angles and distance:**_")
+
+                               demo3_y_angle = gr.Slider(
+                                   minimum=-90,
+                                   maximum=90,
+                                   value=0,
+                                   step=10,
+                                   label="Horizontal Angle",
+                                   interactive=True,
+                               )
+                               demo3_x_angle = gr.Slider(
+                                   minimum=-40,
+                                   maximum=40,
+                                   value=0,
+                                   step=10,
+                                   label="Vertical Angle",
+                                   interactive=True,
+                               )
+                               demo3_distance = gr.Slider(
+                                   minimum=0,
+                                   maximum=200,
+                                   value=100,
+                                   step=10,
+                                   label="Distance",
+                                   interactive=True,
+                               )
+
+                               gr.Button("Generate Next Move", variant="primary").click(
+                                   fn=partial(
+                                       navigate_video,
+                                   ),
+                                   inputs=[demo3_current_video, demo3_current_poses, demo3_x_angle, demo3_y_angle, demo3_distance],
+                                   outputs=[
+                                       demo3_current_video,
+                                       demo3_current_poses,
+                                       demo3_current_view,
+                                       demo3_video,
+                                       demo3_generated_gallery,
+                                   ],
+                               )
+                       with gr.Group():
+                           gr.Markdown("_You can always undo your last move:_")
+                           gr.Button("Undo Last Move", variant="huggingface").click(
+                               fn=undo_navigation,
+                               inputs=[demo3_current_video, demo3_current_poses],
+                               outputs=[
+                                   demo3_current_video,
+                                   demo3_current_poses,
+                                   demo3_current_view,
+                                   demo3_video,
+                                   demo3_generated_gallery,
+                               ],
                           )
-
+                       with gr.Group():
+                           gr.Markdown("_At the end, apply temporal super-resolution to obtain a smoother video:_")
+                           demo3_interpolation_factor=gr.Slider(
                               minimum=2,
                               maximum=10,
-                              value=
+                              value=2,
                               step=1,
-                              label="
-                              info=f"A {LONG_LENGTH}-second video will be generated at this FPS; Decrease for faster generation; Increase for a smoother video",
+                              label="Interpolation Factor",
                               interactive=True,
                           )
-                          gr.Button("
-                          fn=
-                          inputs=[
-
-
-
+                           gr.Button("Smooth Out Video", variant="huggingface").click(
+                               fn=smooth_navigation,
+                               inputs=[demo3_current_video, demo3_current_poses, demo3_interpolation_factor],
+                               outputs=[
+                                   demo3_current_video,
+                                   demo3_current_poses,
+                                   demo3_current_view,
+                                   demo3_video,
+                                   demo3_generated_gallery,
                               ],
-                          outputs=demo2_video,
                           )
 
-   with gr.Tab("Single Image → Extremely Long Video", id="task-3"):
-       gr.Markdown(
-           """
-           ## Demo 3: Single Image → Extremely Long Video
-           > #### _TODO._
-           """
-       )
 
 if __name__ == "__main__":
    demo.launch()
camera_pose.py
ADDED
@@ -0,0 +1,94 @@
+import torch
+from utils.geometry_utils import CameraPose
+from einops import rearrange, repeat
+import math
+import roma
+
+class ControllableCameraPose(CameraPose):
+    def to_vectors(self) -> torch.Tensor:
+        """
+        Returns the raw camera poses.
+        Returns:
+            torch.Tensor: The raw camera poses. Shape (B, T, 4 + 12).
+        """
+        RT = torch.cat([self._R, rearrange(self._T, "b t i -> b t i 1")], dim=-1)
+        return torch.cat([self._K, rearrange(RT, "b t i j -> b t (i j)")], dim=-1)
+
+    def extend(
+        self,
+        num_frames: int,
+        x_angle: float = 0.0,
+        y_angle: float = 0.0,
+        distance: float = 100.0,
+    ) -> None:
+        """
+        Extends the camera poses.
+        Let's say 0 degree is the direction of the last camera pose.
+        Smoothly move & rotate the camera poses in the direction of the given angle (clockwise) in a 2D plane.
+        Args:
+            num_frames (int): The number of frames to extend.
+            x_angle (float): The angle to extend. The angle is in degrees.
+            y_angle (float): The angle to extend. The angle is in degrees.
+        """
+        MOVING_SCALE = 0.5 * distance / 100
+        self._normalize_by(self._R[:, -1], self._T[:, -1])
+
+        # first compute relative poses for the final n + num_frames th frame
+
+        # compute the rotation matrix for the given angle
+        R_final = roma.euler_to_rotmat(
+            convention="xyz",
+            angles=torch.tensor(
+                [-x_angle, -y_angle, 0], device=self._R.device, dtype=torch.float32
+            ),
+            degrees=True,
+            dtype=torch.float32,
+            device=self._R.device,
+        ).unsqueeze(0)
+
+        # compute the translation vector for the given angle
+        T_final = torch.tensor(
+            [
+                -MOVING_SCALE * num_frames * math.sin(math.radians(y_angle)),
+                MOVING_SCALE * num_frames * math.sin(math.radians(x_angle)),
+                -MOVING_SCALE * num_frames * math.cos(math.radians(y_angle)),
+            ],
+            device=self._T.device,
+            dtype=self._T.dtype,
+        ).unsqueeze(0)
+
+        R = torch.cat(
+            [self._R, repeat(R_final, "b i j -> b t i j", t=num_frames).clone()], dim=1
+        )
+        T = torch.cat(
+            [self._T, repeat(T_final, "b i -> b t i", t=num_frames).clone()], dim=1
+        )
+        K = torch.cat(
+            [self._K, repeat(self._K[:, -1], "b i -> b t i", t=num_frames).clone()],
+            dim=1,
+        )
+        self._R = R
+        self._T = T
+        self._K = K
+        # interpolate all frames btwn the last frame and the final frame
+        self.replace_with_interpolation(
+            torch.cat(
+                [
+                    torch.zeros_like(self._T[:, :-num_frames, 0]),
+                    torch.ones_like(self._T[:, -num_frames:-1, 0]),
+                    torch.zeros_like(self._T[:, -1:, 0]),
+                ],
+                dim=-1,
+            ).bool()
+        )
+
+def extend_poses(
+    conditions: torch.Tensor,
+    n: int,
+    x_angle: float = 0.0,
+    y_angle: float = 0.0,
+    distance: float = 0.0,
+) -> torch.Tensor:
+    poses = ControllableCameraPose.from_vectors(conditions)
+    poses.extend(n, x_angle, y_angle, distance)
+    return poses.to_vectors()
history_guidance.py
ADDED
@@ -0,0 +1,24 @@
+from algorithms.dfot.history_guidance import HistoryGuidance as _HistoryGuidance
+
+class HistoryGuidance(_HistoryGuidance):
+    @classmethod
+    def smart(
+        cls,
+        x_angle: float,
+        y_angle: float,
+        distance: float,
+        visualize: bool = False,
+    ):
+        if abs(x_angle) < 30 and abs(y_angle) < 30 and distance < 150:
+            return cls.stabilized_fractional(
+                guidance_scale=4.0,
+                stabilization_level=0.02,
+                freq_scale=0.4,
+                visualize=visualize,
+            )
+        else:
+            return cls.stabilized_vanilla(
+                guidance_scale=4.0,
+                stabilization_level=0.02,
+                visualize=visualize,
+            )