Commit d141aca · committed by deepbeepmeep
Parent(s): c0c0b08

Added Preview mode and support for Sky Reels v2 Diffusion Forcing
Files changed:
- README.md +13 -8
- requirements.txt +1 -1
- wan/__init__.py +1 -0
- wan/image2video.py +2 -2
- wan/modules/model.py +108 -23
- wan/modules/sage2_core.py +1 -1
- wan/text2video.py +5 -31
- wgp.py +0 -0 (diff too large to render)
README.md CHANGED

@@ -10,6 +10,7 @@
 
 
 ## 🔥 Latest News!!
+* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the sliding window section below)
 * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
 * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 denoising steps to get good results.
 * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
@@ -302,18 +303,22 @@ There is also a guide that describes the various combination of hints (https://g
 
 It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
 
-### VACE Sliding Window
-With this mode (that works for the moment only with Vace) you can merge multiple videos to form a very long video (up to 1 min).
+### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
+With this mode (that works for the moment only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).
 
-
+When combined with Vace, this feature can use the same control video to generate the full video that results from concatenating the different windows. For instance the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
 
-
-- *overlap frames* : the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
-- *discard last frames* : quite often the last frames of a window have a worse quality. You decide here how many ending frames of a new window should be dropped.
+When combined with Sky Reels V2, you can extend an existing video indefinitely.
 
-
+Sliding Windows are turned on by default and are triggered as soon as you try to generate a video longer than the Window Size. You can go to the Advanced Settings tab *Sliding Window* to set this Window Size. You can make the video even longer during the generation process by adding one more window to generate each time you click the "Extend the Video Sample, Please !" button.
 
-
+Although the window duration is set by the *Sliding Window Size* form field, the actual number of frames generated by each iteration will be smaller, because of the *overlap frames* and *discard last frames* settings:
+- *overlap frames* : the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
+- *discard last frames* : quite often (Vace model only) the last frames of a window have a worse quality. You can decide here how many ending frames of a new window should be dropped.
+
+Number of Generated Frames = [Number of Windows - 1] * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
+
+Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
 
 ### Command line parameters for Gradio Server
 --i2v : launch the image to video generator\
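The frame-count formula above is easy to sanity-check with a few lines of Python; the window size, overlap and discard values used below are illustrative assumptions, not defaults taken from the application.

```python
# Illustrative check of:
#   Generated Frames = (Windows - 1) * (Window Size - Overlap - Discard) + Window Size
def generated_frames(num_windows: int, window_size: int, overlap: int, discard: int) -> int:
    return (num_windows - 1) * (window_size - overlap - discard) + window_size

# e.g. 3 windows of 81 frames each, 16 overlap frames, 4 discarded frames per window
print(generated_frames(3, 81, 16, 4))  # 203
```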
requirements.txt CHANGED

@@ -12,7 +12,7 @@ ftfy
 dashscope
 imageio-ffmpeg
 # flash_attn
-gradio
+gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
wan/__init__.py CHANGED

@@ -1,3 +1,4 @@
 from . import configs, distributed, modules
 from .image2video import WanI2V
 from .text2video import WanT2V
+from .diffusion_forcing import DTT2V
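With the new export, the Sky Reels v2 diffusion-forcing pipeline can be imported from the package root alongside the existing pipelines; the constructor arguments of DTT2V are not part of this diff, so the sketch below only shows the import surface.

```python
# The package root now exposes the diffusion-forcing pipeline next to the
# text-to-video and image-to-video pipelines (constructor arguments omitted,
# since they are not shown in this diff).
from wan import WanT2V, WanI2V, DTT2V
```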
wan/image2video.py CHANGED

@@ -352,7 +352,7 @@ class WanI2V:
 
         # self.model.to(self.device)
         if callback != None:
-            callback(-1, True)
+            callback(-1, None, True)
 
         for i, t in enumerate(tqdm(timesteps)):
             offload.set_step_no_for_lora(self.model, i)
@@ -426,7 +426,7 @@
                 del timestep
 
                 if callback is not None:
-                    callback(i, False)
+                    callback(i, latent, False)
 
 
         x0 = [latent.to(self.device, dtype=self.dtype)]
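Both call sites now pass the current latent to the callback, which is what enables the new preview mode: a UI-side handler can decode the latent between denoising steps. The sketch below is a hypothetical callback whose shape is only inferred from the call sites in this commit (callback(-1, None, True) at start, callback(i, latent, False) per step, callback(-1, None, False, True) from inside the transformer); the real handler lives in wgp.py, whose diff is not rendered here, and the meaning of the boolean flags is an assumption.

```python
# Hypothetical callback compatible with the call sites shown in this commit.
# The two boolean flags are interpreted here as "refresh the UI" and "called from
# inside a transformer block" -- an assumption, since wgp.py is not shown.
def preview_callback(step_no, latent, force_refresh, inner=False):
    if latent is None:
        return  # start-of-generation or progress-only ping, nothing to preview
    # a real implementation would VAE-decode `latent` into a low-res preview frame
    print(f"step {step_no}: latent tensor of shape {tuple(latent.shape)}")
```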
wan/modules/model.py CHANGED

@@ -10,6 +10,7 @@ import numpy as np
 from typing import Union,Optional
 from mmgp import offload
 from .attention import pay_attention
+from torch.backends.cuda import sdp_kernel
 
 __all__ = ['WanModel']
 
@@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
+def reshape_latent(latent, latent_frames):
+    if latent_frames == latent.shape[0]:
+        return latent
+    return latent.reshape(latent_frames, -1, latent.shape[-1] )
 
 
 def identify_k( b: float, d: int, N: int):
@@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
         self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
 
-    def forward(self, xlist, grid_sizes, freqs):
+    def forward(self, xlist, grid_sizes, freqs, block_mask = None):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -190,12 +195,44 @@
         del x
         qklist = [q,k]
         del q,k
+
         q,k = apply_rotary_emb(qklist, freqs, head_first=False)
         qkv_list = [q,k,v]
         del q,k,v
-        x = pay_attention(
-            qkv_list,
-            window_size=self.window_size)
+        if block_mask == None:
+            x = pay_attention(
+                qkv_list,
+                window_size=self.window_size)
+        else:
+            with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+                x = (
+                    torch.nn.functional.scaled_dot_product_attention(
+                        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+
+        # if not self._flag_ar_attention:
+        #     q = rope_apply(q, grid_sizes, freqs)
+        #     k = rope_apply(k, grid_sizes, freqs)
+        #     x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
+        # else:
+        #     q = rope_apply(q, grid_sizes, freqs)
+        #     k = rope_apply(k, grid_sizes, freqs)
+        #     q = q.to(torch.bfloat16)
+        #     k = k.to(torch.bfloat16)
+        #     v = v.to(torch.bfloat16)
+
+        #     with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+        #         x = (
+        #             torch.nn.functional.scaled_dot_product_attention(
+        #                 q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
+        #             )
+        #             .transpose(1, 2)
+        #             .contiguous()
+        #         )
+
         # output
         x = x.flatten(2)
         x = self.o(x)
@@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
         context,
         hints= None,
         context_scale=1.0,
-        cam_emb= None
+        cam_emb= None,
+        block_mask = None
     ):
         r"""
         Args:
@@ -381,13 +419,14 @@
             hint = self.vace(hints, x, **kwargs)
         else:
             hint = self.vace(hints, None, **kwargs)
-
+        latent_frames = e.shape[0]
         e = (self.modulation + e).chunk(6, dim=1)
-
         # self-attention
         x_mod = self.norm1(x)
+        x_mod = reshape_latent(x_mod , latent_frames)
        x_mod *= 1 + e[1]
         x_mod += e[0]
+        x_mod = reshape_latent(x_mod , 1)
         if cam_emb != None:
             cam_emb = self.cam_encoder(cam_emb)
             cam_emb = cam_emb.repeat(1, 2, 1)
@@ -397,12 +436,13 @@
 
         xlist = [x_mod]
         del x_mod
-        y = self.self_attn( xlist, grid_sizes, freqs)
+        y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
         if cam_emb != None:
             y = self.projector(y)
-        # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
 
+        x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
         x.addcmul_(y, e[2])
+        x, y = reshape_latent(x , 1), reshape_latent(y , 1)
         del y
         y = self.norm3(x)
         ylist= [y]
@@ -410,8 +450,10 @@
         x += self.cross_attn(ylist, context)
         y = self.norm2(x)
 
+        y = reshape_latent(y , latent_frames)
         y *= 1 + e[4]
         y += e[3]
+        y = reshape_latent(y , 1)
 
         ffn = self.ffn[0]
         gelu = self.ffn[1]
@@ -428,7 +470,9 @@
             del mlp_chunk
         y = y.view(y_shape)
 
+        x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
        x.addcmul_(y, e[5])
+        x, y = reshape_latent(x , 1), reshape_latent(y , 1)
 
         if hint is not None:
             if context_scale == 1:
@@ -500,10 +544,14 @@ class Head(nn.Module):
         """
         # assert e.dtype == torch.float32
         dtype = x.dtype
+
+        latent_frames = e.shape[0]
         e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
         x = self.norm(x).to(dtype)
+        x = reshape_latent(x , latent_frames)
         x *= (1 + e[1])
         x += e[0]
+        x = reshape_latent(x , 1)
         x = self.head(x)
         return x
 
@@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
         qk_norm=True,
         cross_attn_norm=True,
         eps=1e-6,
-        recammaster = False
+        recammaster = False,
+        inject_sample_info = False,
     ):
         r"""
         Initialize the diffusion model backbone.
@@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.num_frame_per_block = 1
+        self.flag_causal_attention = False
+        self.block_mask = None
+        self.inject_sample_info = inject_sample_info
 
         # embeddings
         self.patch_embedding = nn.Conv3d(
@@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
             nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
             nn.Linear(dim, dim))
 
+        if inject_sample_info:
+            self.fps_embedding = nn.Embedding(2, dim)
+            self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
+
         self.time_embedding = nn.Sequential(
             nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
@@ -678,12 +735,13 @@ class WanModel(ModelMixin, ConfigMixin):
                 block.projector.bias = nn.Parameter(torch.zeros(dim))
 
 
-    def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
+    def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
         rescale_func = np.poly1d(self.coefficients)
         e_list = []
         for t in timesteps:
             t = torch.stack([t])
-
+            time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
+            e_list.append(time_emb)
 
         best_threshold = 0.01
         best_diff = 1000
@@ -695,16 +753,13 @@
         nb_steps = 0
         diff = 1000
         for i, t in enumerate(timesteps):
-            skip = False
+            skip = False
             if not (i<=start_step or i== len(timesteps)):
-                accumulated_rel_l1_distance += rescale_func(((e_list[i]-
-                # self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
-
+                accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
                 if accumulated_rel_l1_distance < threshold:
                     skip = True
                 else:
                     accumulated_rel_l1_distance = 0
-                    previous_modulated_input = e_list[i]
             if not skip:
                 nb_steps += 1
             signed_diff = target_nb_steps - nb_steps
@@ -739,6 +794,9 @@
         slg_layers=None,
         callback = None,
         cam_emb: torch.Tensor = None,
+        fps = None,
+        causal_block_size = 1,
+        causal_attention = False,
     ):
 
         if self.model_type == 'i2v':
@@ -752,26 +810,53 @@
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
 
         # embeddings
-        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         # grid_sizes = torch.stack(
         #     [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
 
         grid_sizes = [ list(u.shape[2:]) for u in x]
         embed_sizes = grid_sizes[0]
+        if causal_attention : #causal_block_size > 0:
+            frame_num = embed_sizes[0]
+            height = embed_sizes[1]
+            width = embed_sizes[2]
+            block_num = frame_num // causal_block_size
+            range_tensor = torch.arange(block_num).view(-1, 1)
+            range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
+            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
+            causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
+            causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
+            causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
+            block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+            del causal_mask
 
         offload.shared_state["embed_sizes"] = embed_sizes
         offload.shared_state["step_no"] = current_step
         offload.shared_state["max_steps"] = max_steps
 
-
         x = [u.flatten(2).transpose(1, 2) for u in x]
         x = x[0]
 
-
+        if t.dim() == 2:
+            b, f = t.shape
+            _flag_df = True
+        else:
+            _flag_df = False
+
         e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t))
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
+        ) # b, dim
         e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
 
+        if self.inject_sample_info:
+            fps = torch.tensor(fps, dtype=torch.long, device=device)
+
+            fps_emb = self.fps_embedding(fps).float()
+            if _flag_df:
+                e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
+            else:
+                e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
+
         # context
         context = self.text_embedding(
             torch.stack([
@@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
                 self.accumulated_rel_l1_distance = 0
             else:
                 rescale_func = np.poly1d(self.coefficients)
-                self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+                self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
                 if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                     should_calc = False
                     self.teacache_skipped_steps += 1
@@ -858,7 +943,7 @@
         for block_idx, block in enumerate(self.blocks):
             offload.shared_state["layer"] = block_idx
             if callback != None:
-                callback(-1, False, True)
+                callback(-1, None, False, True)
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
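The causal_attention branch added to WanModel.forward builds a block-causal attention mask over the latent frames and passes it to scaled_dot_product_attention as block_mask; this is how the Sky Reels v2 diffusion-forcing path lets each block of frames attend only to itself and to earlier blocks. The standalone snippet below replays that construction on tiny, made-up dimensions so the resulting shapes are easy to inspect.

```python
import torch

# Replays the block-causal mask construction from WanModel.forward above on tiny,
# illustrative dimensions: 4 latent frames, 2x3 spatial tokens, causal_block_size=2.
frame_num, height, width, causal_block_size = 4, 2, 3, 2

block_num = frame_num // causal_block_size
range_tensor = torch.arange(block_num).view(-1, 1).repeat(1, causal_block_size).flatten()
# True where a query token's frame block is >= the key token's frame block
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)             # (f, f)
causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1)
causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)    # (1, 1, L, L), ready for attn_mask

print(block_mask.shape)           # torch.Size([1, 1, 24, 24])
print(block_mask[0, 0, 0].sum())  # a token in the first frame block sees 2*2*3 = 12 tokens
```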
wan/modules/sage2_core.py CHANGED

@@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
     q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
 
     q_size = q.size()
+    kv_len = k.size(seq_dim)
     q_device = q.device
     del q,k
 
 
     # pad v to multiple of 128
     # TODO: modify per_channel_fp8 kernel to handle this
-    kv_len = k.size(seq_dim)
     v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
     if v_pad_len > 0:
         if tensor_layout == "HND":
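This is a small ordering fix: kv_len is now read from k before del q,k, whereas the previous code called k.size(seq_dim) after the name had already been deleted. A toy reproduction of the corrected ordering, under an assumed (batch, heads, seq, dim) layout:

```python
import torch

# Toy illustration of the ordering fix, with an assumed (batch, heads, seq, dim) layout.
k = torch.zeros(1, 8, 200, 64)
seq_dim = 2

kv_len = k.size(seq_dim)   # capture the sequence length while `k` still exists
del k                      # the old code deleted k first, so k.size(...) raised NameError
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
print(kv_len, v_pad_len)   # 200 56
```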
wan/text2video.py CHANGED

@@ -49,40 +49,14 @@ class WanT2V:
         config,
         checkpoint_dir,
         rank=0,
-        t5_fsdp=False,
-        dit_fsdp=False,
-        use_usp=False,
-        t5_cpu=False,
         model_filename = None,
         text_encoder_filename = None,
         quantizeTransformer = False,
         dtype = torch.bfloat16
     ):
-        r"""
-        Initializes the Wan text-to-video generation model components.
-
-        Args:
-            config (EasyDict):
-                Object containing model parameters initialized from config.py
-            checkpoint_dir (`str`):
-                Path to directory containing model checkpoints
-            device_id (`int`, *optional*, defaults to 0):
-                Id of target GPU device
-            rank (`int`, *optional*, defaults to 0):
-                Process rank for distributed training
-            t5_fsdp (`bool`, *optional*, defaults to False):
-                Enable FSDP sharding for T5 model
-            dit_fsdp (`bool`, *optional*, defaults to False):
-                Enable FSDP sharding for DiT model
-            use_usp (`bool`, *optional*, defaults to False):
-                Enable distribution strategy of USP.
-            t5_cpu (`bool`, *optional*, defaults to False):
-                Whether to place T5 model on CPU. Only works without t5_fsdp.
-        """
         self.device = torch.device(f"cuda")
         self.config = config
         self.rank = rank
-        self.t5_cpu = t5_cpu
         self.dtype = dtype
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
@@ -419,9 +393,9 @@
             freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
         else:
             freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
-        arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
-        arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
-        arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
+        arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
+        arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
+        arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
 
         if target_camera != None:
             recam_dict = {'cam_emb': cam_emb}
@@ -438,7 +412,7 @@
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
-            callback(-1, True)
+            callback(-1, None, True)
         for i, t in enumerate(tqdm(timesteps)):
             if target_camera != None:
                 latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
@@ -494,7 +468,7 @@
                 del temp_x0
 
             if callback is not None:
-                callback(i, False)
+                callback(i, latents[0], False)
 
         x0 = latents
 
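Adding 'callback': callback to arg_c, arg_null and arg_both forwards the callback into the transformer's forward call, where it is pinged once per block (the callback(-1, None, False, True) call in model.py); this lets the UI refresh previews or honour an abort in the middle of a denoising step rather than only between steps. A minimal sketch of that forwarding pattern, using stand-in names:

```python
# Minimal sketch (stand-in names) of forwarding the callback through a kwargs dict,
# mirroring how arg_c / arg_null / arg_both are expanded into the model's forward call.
def fake_forward(latents, context=None, freqs=None, pipeline=None, callback=None):
    for block_idx in range(3):                 # stand-in for iterating self.blocks
        if callback is not None:
            callback(-1, None, False, True)    # same call shape as in model.py above
    return latents

arg_c = {"context": "ctx", "freqs": None, "pipeline": None, "callback": print}
fake_forward("latents", **arg_c)
```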
wgp.py CHANGED
The diff for this file is too large to render. See raw diff.