deepbeepmeep committed on
Commit
d141aca
·
1 Parent(s): c0c0b08

Added Preview mode and support for Sky Reels v2 Diffusion Forcing

README.md CHANGED
@@ -10,6 +10,7 @@
10
 
11
 
12
  ## 🔥 Latest News!!
 
13
  * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
14
  * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
15
  * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
@@ -302,18 +303,22 @@ There is also a guide that describes the various combination of hints (https://g
302
 
303
  It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
304
 
305
- ### VACE Slidig Window
306
- With this mode (that works for the moment only with Vace) you can merge mutiple Videos to form a very long video (up to 1 min). What is this very nice a about this feature is that the resulting video can be driven by the same control video. For instance the first 0-4s of the control video will be used to generate the first window then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generate video could contain up to one minute of this person walking.
307
 
308
- To turn on sliding window, you need to go in the Advanced Settings Tab *Sliding Window* and set the iteration number to a number greater than 1. This number corresponds to the default number of windows. You can still increase the number during the genreation by clicking the "One More Sample, Please !" button.
309
 
310
- Each window duration will be set by the *Number of frames (16 = 1s)* form field. However the actual number of frames generated by each iteration will be less, because the *overlap frames* and *discard last frames*:
311
- - *overlap frames* : the first frames ofa new window are filled with last frames of the previous window in order to ensure continuity between the two windows
312
- - *discard last frames* : quite often the last frames of a window have a worse quality. You decide here how many ending frames of a new window should be dropped.
313
 
314
- Number of Generated = [Number of iterations] * ([Number of frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
315
 
316
- Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
 
 
 
 
 
 
317
 
318
  ### Command line parameters for Gradio Server
319
  --i2v : launch the image to video generator\
 
10
 
11
 
12
  ## 🔥 Latest News!!
13
+ * April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high-quality "infinite length videos" (see the Sliding Window section below)
14
  * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
15
  * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
16
  * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
 
303
 
304
  It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
305
 
306
+ ### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
307
+ With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).
308
 
309
+ When combined with Vace, this feature can use the same control video to generate the full video that results from concatenating the different windows. For instance, the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
310
 
311
+ When combined with Sky Reels V2, you can extend an existing video indefinitely.
 
 
312
 
313
+ Sliding Windows are turned on by default and are triggered as soon as you try to generate a video longer than the Window Size. You can set this Window Size in the Advanced Settings Tab *Sliding Window*. You can make the video even longer during the generation process: one more window is added each time you click the "Extend the Video Sample, Please !" button.
314
 
315
+ Although the window duration is set by the *Sliding Window Size* form field, the actual number of new frames produced by each iteration will be smaller because of the *overlap frames* and *discard last frames* settings:
316
+ - *overlap frames* : the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
317
+ - *discard last frames* : quite often (Vace model only) the last frames of a window have worse quality. You can decide here how many ending frames of a new window should be dropped.
318
+
319
+ Number of Generated Frames = ([Number of Windows] - 1) * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
320
+
321
+ Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
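As a quick sanity check of the frame-count formula above, here is a minimal Python sketch; the function name and the sample values (81-frame windows, 17 overlap, 4 discarded) are purely illustrative and not taken from the commit:

```python
def total_generated_frames(num_windows: int, window_size: int,
                           overlap_frames: int, discard_last_frames: int) -> int:
    """Frames in the merged video after all sliding windows are generated.

    The first window contributes window_size frames; every additional window
    only adds (window_size - overlap_frames - discard_last_frames) new frames,
    since the overlap repeats the previous window's ending and the discarded
    frames are dropped.
    """
    new_frames_per_extra_window = window_size - overlap_frames - discard_last_frames
    return window_size + (num_windows - 1) * new_frames_per_extra_window

# Example: 4 windows of 81 frames, 17 overlap frames, 4 discarded frames
# -> 81 + 3 * (81 - 17 - 4) = 261 frames total
print(total_generated_frames(4, 81, 17, 4))  # 261
```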
322
 
323
  ### Command line parameters for Gradio Server
324
  --i2v : launch the image to video generator\
requirements.txt CHANGED
@@ -12,7 +12,7 @@ ftfy
12
  dashscope
13
  imageio-ffmpeg
14
  # flash_attn
15
- gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
 
12
  dashscope
13
  imageio-ffmpeg
14
  # flash_attn
15
+ gradio==5.23.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
wan/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from . import configs, distributed, modules
2
  from .image2video import WanI2V
3
  from .text2video import WanT2V
 
 
1
  from . import configs, distributed, modules
2
  from .image2video import WanI2V
3
  from .text2video import WanT2V
4
+ from .diffusion_forcing import DTT2V
wan/image2video.py CHANGED
@@ -352,7 +352,7 @@ class WanI2V:
352
 
353
  # self.model.to(self.device)
354
  if callback != None:
355
- callback(-1, True)
356
 
357
  for i, t in enumerate(tqdm(timesteps)):
358
  offload.set_step_no_for_lora(self.model, i)
@@ -426,7 +426,7 @@ class WanI2V:
426
  del timestep
427
 
428
  if callback is not None:
429
- callback(i, False)
430
 
431
 
432
  x0 = [latent.to(self.device, dtype=self.dtype)]
 
352
 
353
  # self.model.to(self.device)
354
  if callback != None:
355
+ callback(-1, None, True)
356
 
357
  for i, t in enumerate(tqdm(timesteps)):
358
  offload.set_step_no_for_lora(self.model, i)
 
426
  del timestep
427
 
428
  if callback is not None:
429
+ callback(i, latent, False)
430
 
431
 
432
  x0 = [latent.to(self.device, dtype=self.dtype)]
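The diff above shows that the denoising callback now receives the current latent as a second argument (`callback(i, latent, False)` instead of `callback(i, False)`, and `callback(-1, None, True)` for the initial call), which is what enables the new preview mode. Below is a minimal sketch of a callback compatible with this signature; the preview logic inside it and the name of the fourth argument are assumptions for illustration, not part of this commit:

```python
# Hypothetical consumer of the new callback signature: (step_no, latent, force_refresh[, read_state])
def preview_callback(step_no, latent, force_refresh, read_state=False):
    if read_state:
        # Per-block call from the transformer (callback(-1, None, False, True)):
        # only used to let the UI poll for an abort request, nothing to preview.
        return
    if latent is None:
        # Initial call callback(-1, None, True): reset the progress display.
        return
    # A real implementation would decode `latent` with the VAE (or a cheap
    # approximation) every few steps and push the frame to the UI as a preview.
    print(f"step {step_no}: latent shape {tuple(latent.shape)}")
```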
wan/modules/model.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  from typing import Union,Optional
11
  from mmgp import offload
12
  from .attention import pay_attention
 
13
 
14
  __all__ = ['WanModel']
15
 
@@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
27
  return x
28
 
29
 
 
 
 
 
30
 
31
 
32
  def identify_k( b: float, d: int, N: int):
@@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
167
  self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
168
  self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
169
 
170
- def forward(self, xlist, grid_sizes, freqs):
171
  r"""
172
  Args:
173
  x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -190,12 +195,44 @@ class WanSelfAttention(nn.Module):
190
  del x
191
  qklist = [q,k]
192
  del q,k
 
193
  q,k = apply_rotary_emb(qklist, freqs, head_first=False)
194
  qkv_list = [q,k,v]
195
  del q,k,v
196
- x = pay_attention(
197
- qkv_list,
198
- window_size=self.window_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  # output
200
  x = x.flatten(2)
201
  x = self.o(x)
@@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
360
  context,
361
  hints= None,
362
  context_scale=1.0,
363
- cam_emb= None
 
364
  ):
365
  r"""
366
  Args:
@@ -381,13 +419,14 @@ class WanAttentionBlock(nn.Module):
381
  hint = self.vace(hints, x, **kwargs)
382
  else:
383
  hint = self.vace(hints, None, **kwargs)
384
-
385
  e = (self.modulation + e).chunk(6, dim=1)
386
-
387
  # self-attention
388
  x_mod = self.norm1(x)
 
389
  x_mod *= 1 + e[1]
390
  x_mod += e[0]
 
391
  if cam_emb != None:
392
  cam_emb = self.cam_encoder(cam_emb)
393
  cam_emb = cam_emb.repeat(1, 2, 1)
@@ -397,12 +436,13 @@ class WanAttentionBlock(nn.Module):
397
 
398
  xlist = [x_mod]
399
  del x_mod
400
- y = self.self_attn( xlist, grid_sizes, freqs)
401
  if cam_emb != None:
402
  y = self.projector(y)
403
- # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
404
 
 
405
  x.addcmul_(y, e[2])
 
406
  del y
407
  y = self.norm3(x)
408
  ylist= [y]
@@ -410,8 +450,10 @@ class WanAttentionBlock(nn.Module):
410
  x += self.cross_attn(ylist, context)
411
  y = self.norm2(x)
412
 
 
413
  y *= 1 + e[4]
414
  y += e[3]
 
415
 
416
  ffn = self.ffn[0]
417
  gelu = self.ffn[1]
@@ -428,7 +470,9 @@ class WanAttentionBlock(nn.Module):
428
  del mlp_chunk
429
  y = y.view(y_shape)
430
 
 
431
  x.addcmul_(y, e[5])
 
432
 
433
  if hint is not None:
434
  if context_scale == 1:
@@ -500,10 +544,14 @@ class Head(nn.Module):
500
  """
501
  # assert e.dtype == torch.float32
502
  dtype = x.dtype
 
 
503
  e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
504
  x = self.norm(x).to(dtype)
 
505
  x *= (1 + e[1])
506
  x += e[0]
 
507
  x = self.head(x)
508
  return x
509
 
@@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
552
  qk_norm=True,
553
  cross_attn_norm=True,
554
  eps=1e-6,
555
- recammaster = False
 
556
  ):
557
  r"""
558
  Initialize the diffusion model backbone.
@@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
609
  self.qk_norm = qk_norm
610
  self.cross_attn_norm = cross_attn_norm
611
  self.eps = eps
 
 
 
 
612
 
613
  # embeddings
614
  self.patch_embedding = nn.Conv3d(
@@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
617
  nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
618
  nn.Linear(dim, dim))
619
 
 
 
 
 
620
  self.time_embedding = nn.Sequential(
621
  nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
622
  self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
@@ -678,12 +735,13 @@ class WanModel(ModelMixin, ConfigMixin):
678
  block.projector.bias = nn.Parameter(torch.zeros(dim))
679
 
680
 
681
- def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
682
  rescale_func = np.poly1d(self.coefficients)
683
  e_list = []
684
  for t in timesteps:
685
  t = torch.stack([t])
686
- e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
 
687
 
688
  best_threshold = 0.01
689
  best_diff = 1000
@@ -695,16 +753,13 @@ class WanModel(ModelMixin, ConfigMixin):
695
  nb_steps = 0
696
  diff = 1000
697
  for i, t in enumerate(timesteps):
698
- skip = False
699
  if not (i<=start_step or i== len(timesteps)):
700
- accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
701
- # self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
702
-
703
  if accumulated_rel_l1_distance < threshold:
704
  skip = True
705
  else:
706
  accumulated_rel_l1_distance = 0
707
- previous_modulated_input = e_list[i]
708
  if not skip:
709
  nb_steps += 1
710
  signed_diff = target_nb_steps - nb_steps
@@ -739,6 +794,9 @@ class WanModel(ModelMixin, ConfigMixin):
739
  slg_layers=None,
740
  callback = None,
741
  cam_emb: torch.Tensor = None,
 
 
 
742
  ):
743
 
744
  if self.model_type == 'i2v':
@@ -752,26 +810,53 @@ class WanModel(ModelMixin, ConfigMixin):
752
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
753
 
754
  # embeddings
755
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
756
  # grid_sizes = torch.stack(
757
  # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
758
 
759
  grid_sizes = [ list(u.shape[2:]) for u in x]
760
  embed_sizes = grid_sizes[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
761
 
762
  offload.shared_state["embed_sizes"] = embed_sizes
763
  offload.shared_state["step_no"] = current_step
764
  offload.shared_state["max_steps"] = max_steps
765
 
766
-
767
  x = [u.flatten(2).transpose(1, 2) for u in x]
768
  x = x[0]
769
 
770
- # time embeddings
 
 
 
 
 
771
  e = self.time_embedding(
772
- sinusoidal_embedding_1d(self.freq_dim, t))
 
773
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
774
 
 
 
 
 
 
 
 
 
 
775
  # context
776
  context = self.text_embedding(
777
  torch.stack([
@@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
833
  self.accumulated_rel_l1_distance = 0
834
  else:
835
  rescale_func = np.poly1d(self.coefficients)
836
- self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
837
  if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
838
  should_calc = False
839
  self.teacache_skipped_steps += 1
@@ -858,7 +943,7 @@ class WanModel(ModelMixin, ConfigMixin):
858
  for block_idx, block in enumerate(self.blocks):
859
  offload.shared_state["layer"] = block_idx
860
  if callback != None:
861
- callback(-1, False, True)
862
  if pipeline._interrupt:
863
  if joint_pass:
864
  return None, None
 
10
  from typing import Union,Optional
11
  from mmgp import offload
12
  from .attention import pay_attention
13
+ from torch.backends.cuda import sdp_kernel
14
 
15
  __all__ = ['WanModel']
16
 
 
28
  return x
29
 
30
 
31
+ def reshape_latent(latent, latent_frames):
32
+ if latent_frames == latent.shape[0]:
33
+ return latent
34
+ return latent.reshape(latent_frames, -1, latent.shape[-1] )
35
 
36
 
37
  def identify_k( b: float, d: int, N: int):
 
172
  self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
173
  self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
174
 
175
+ def forward(self, xlist, grid_sizes, freqs, block_mask = None):
176
  r"""
177
  Args:
178
  x(Tensor): Shape [B, L, num_heads, C / num_heads]
 
195
  del x
196
  qklist = [q,k]
197
  del q,k
198
+
199
  q,k = apply_rotary_emb(qklist, freqs, head_first=False)
200
  qkv_list = [q,k,v]
201
  del q,k,v
202
+ if block_mask == None:
203
+ x = pay_attention(
204
+ qkv_list,
205
+ window_size=self.window_size)
206
+ else:
207
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
208
+ x = (
209
+ torch.nn.functional.scaled_dot_product_attention(
210
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
211
+ )
212
+ .transpose(1, 2)
213
+ .contiguous()
214
+ )
215
+
216
+ # if not self._flag_ar_attention:
217
+ # q = rope_apply(q, grid_sizes, freqs)
218
+ # k = rope_apply(k, grid_sizes, freqs)
219
+ # x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
220
+ # else:
221
+ # q = rope_apply(q, grid_sizes, freqs)
222
+ # k = rope_apply(k, grid_sizes, freqs)
223
+ # q = q.to(torch.bfloat16)
224
+ # k = k.to(torch.bfloat16)
225
+ # v = v.to(torch.bfloat16)
226
+
227
+ # with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
228
+ # x = (
229
+ # torch.nn.functional.scaled_dot_product_attention(
230
+ # q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
231
+ # )
232
+ # .transpose(1, 2)
233
+ # .contiguous()
234
+ # )
235
+
236
  # output
237
  x = x.flatten(2)
238
  x = self.o(x)
 
397
  context,
398
  hints= None,
399
  context_scale=1.0,
400
+ cam_emb= None,
401
+ block_mask = None
402
  ):
403
  r"""
404
  Args:
 
419
  hint = self.vace(hints, x, **kwargs)
420
  else:
421
  hint = self.vace(hints, None, **kwargs)
422
+ latent_frames = e.shape[0]
423
  e = (self.modulation + e).chunk(6, dim=1)
 
424
  # self-attention
425
  x_mod = self.norm1(x)
426
+ x_mod = reshape_latent(x_mod , latent_frames)
427
  x_mod *= 1 + e[1]
428
  x_mod += e[0]
429
+ x_mod = reshape_latent(x_mod , 1)
430
  if cam_emb != None:
431
  cam_emb = self.cam_encoder(cam_emb)
432
  cam_emb = cam_emb.repeat(1, 2, 1)
 
436
 
437
  xlist = [x_mod]
438
  del x_mod
439
+ y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
440
  if cam_emb != None:
441
  y = self.projector(y)
 
442
 
443
+ x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
444
  x.addcmul_(y, e[2])
445
+ x, y = reshape_latent(x , 1), reshape_latent(y , 1)
446
  del y
447
  y = self.norm3(x)
448
  ylist= [y]
 
450
  x += self.cross_attn(ylist, context)
451
  y = self.norm2(x)
452
 
453
+ y = reshape_latent(y , latent_frames)
454
  y *= 1 + e[4]
455
  y += e[3]
456
+ y = reshape_latent(y , 1)
457
 
458
  ffn = self.ffn[0]
459
  gelu = self.ffn[1]
 
470
  del mlp_chunk
471
  y = y.view(y_shape)
472
 
473
+ x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
474
  x.addcmul_(y, e[5])
475
+ x, y = reshape_latent(x , 1), reshape_latent(y , 1)
476
 
477
  if hint is not None:
478
  if context_scale == 1:
 
544
  """
545
  # assert e.dtype == torch.float32
546
  dtype = x.dtype
547
+
548
+ latent_frames = e.shape[0]
549
  e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
550
  x = self.norm(x).to(dtype)
551
+ x = reshape_latent(x , latent_frames)
552
  x *= (1 + e[1])
553
  x += e[0]
554
+ x = reshape_latent(x , 1)
555
  x = self.head(x)
556
  return x
557
 
 
600
  qk_norm=True,
601
  cross_attn_norm=True,
602
  eps=1e-6,
603
+ recammaster = False,
604
+ inject_sample_info = False,
605
  ):
606
  r"""
607
  Initialize the diffusion model backbone.
 
658
  self.qk_norm = qk_norm
659
  self.cross_attn_norm = cross_attn_norm
660
  self.eps = eps
661
+ self.num_frame_per_block = 1
662
+ self.flag_causal_attention = False
663
+ self.block_mask = None
664
+ self.inject_sample_info = inject_sample_info
665
 
666
  # embeddings
667
  self.patch_embedding = nn.Conv3d(
 
670
  nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
671
  nn.Linear(dim, dim))
672
 
673
+ if inject_sample_info:
674
+ self.fps_embedding = nn.Embedding(2, dim)
675
+ self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
676
+
677
  self.time_embedding = nn.Sequential(
678
  nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
679
  self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
 
735
  block.projector.bias = nn.Parameter(torch.zeros(dim))
736
 
737
 
738
+ def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
739
  rescale_func = np.poly1d(self.coefficients)
740
  e_list = []
741
  for t in timesteps:
742
  t = torch.stack([t])
743
+ time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
744
+ e_list.append(time_emb)
745
 
746
  best_threshold = 0.01
747
  best_diff = 1000
 
753
  nb_steps = 0
754
  diff = 1000
755
  for i, t in enumerate(timesteps):
756
+ skip = False
757
  if not (i<=start_step or i== len(timesteps)):
758
+ accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
 
 
759
  if accumulated_rel_l1_distance < threshold:
760
  skip = True
761
  else:
762
  accumulated_rel_l1_distance = 0
 
763
  if not skip:
764
  nb_steps += 1
765
  signed_diff = target_nb_steps - nb_steps
 
794
  slg_layers=None,
795
  callback = None,
796
  cam_emb: torch.Tensor = None,
797
+ fps = None,
798
+ causal_block_size = 1,
799
+ causal_attention = False,
800
  ):
801
 
802
  if self.model_type == 'i2v':
 
810
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
811
 
812
  # embeddings
813
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
814
  # grid_sizes = torch.stack(
815
  # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
816
 
817
  grid_sizes = [ list(u.shape[2:]) for u in x]
818
  embed_sizes = grid_sizes[0]
819
+ if causal_attention : #causal_block_size > 0:
820
+ frame_num = embed_sizes[0]
821
+ height = embed_sizes[1]
822
+ width = embed_sizes[2]
823
+ block_num = frame_num // causal_block_size
824
+ range_tensor = torch.arange(block_num).view(-1, 1)
825
+ range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
826
+ causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
827
+ causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
828
+ causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
829
+ causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
830
+ block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
831
+ del causal_mask
832
 
833
  offload.shared_state["embed_sizes"] = embed_sizes
834
  offload.shared_state["step_no"] = current_step
835
  offload.shared_state["max_steps"] = max_steps
836
 
 
837
  x = [u.flatten(2).transpose(1, 2) for u in x]
838
  x = x[0]
839
 
840
+ if t.dim() == 2:
841
+ b, f = t.shape
842
+ _flag_df = True
843
+ else:
844
+ _flag_df = False
845
+
846
  e = self.time_embedding(
847
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
848
+ ) # b, dim
849
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
850
 
851
+ if self.inject_sample_info:
852
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
853
+
854
+ fps_emb = self.fps_embedding(fps).float()
855
+ if _flag_df:
856
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
857
+ else:
858
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
859
+
860
  # context
861
  context = self.text_embedding(
862
  torch.stack([
 
918
  self.accumulated_rel_l1_distance = 0
919
  else:
920
  rescale_func = np.poly1d(self.coefficients)
921
+ self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
922
  if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
923
  should_calc = False
924
  self.teacache_skipped_steps += 1
 
943
  for block_idx, block in enumerate(self.blocks):
944
  offload.shared_state["layer"] = block_idx
945
  if callback != None:
946
+ callback(-1, None, False, True)
947
  if pipeline._interrupt:
948
  if joint_pass:
949
  return None, None
wan/modules/sage2_core.py CHANGED
@@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
1075
  q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
1076
 
1077
  q_size = q.size()
 
1078
  q_device = q.device
1079
  del q,k
1080
 
1081
 
1082
  # pad v to multiple of 128
1083
  # TODO: modify per_channel_fp8 kernel to handle this
1084
- kv_len = k.size(seq_dim)
1085
  v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
1086
  if v_pad_len > 0:
1087
  if tensor_layout == "HND":
 
1075
  q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
1076
 
1077
  q_size = q.size()
1078
+ kv_len = k.size(seq_dim)
1079
  q_device = q.device
1080
  del q,k
1081
 
1082
 
1083
  # pad v to multiple of 128
1084
  # TODO: modify per_channel_fp8 kernel to handle this
 
1085
  v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
1086
  if v_pad_len > 0:
1087
  if tensor_layout == "HND":
wan/text2video.py CHANGED
@@ -49,40 +49,14 @@ class WanT2V:
49
  config,
50
  checkpoint_dir,
51
  rank=0,
52
- t5_fsdp=False,
53
- dit_fsdp=False,
54
- use_usp=False,
55
- t5_cpu=False,
56
  model_filename = None,
57
  text_encoder_filename = None,
58
  quantizeTransformer = False,
59
  dtype = torch.bfloat16
60
  ):
61
- r"""
62
- Initializes the Wan text-to-video generation model components.
63
-
64
- Args:
65
- config (EasyDict):
66
- Object containing model parameters initialized from config.py
67
- checkpoint_dir (`str`):
68
- Path to directory containing model checkpoints
69
- device_id (`int`, *optional*, defaults to 0):
70
- Id of target GPU device
71
- rank (`int`, *optional*, defaults to 0):
72
- Process rank for distributed training
73
- t5_fsdp (`bool`, *optional*, defaults to False):
74
- Enable FSDP sharding for T5 model
75
- dit_fsdp (`bool`, *optional*, defaults to False):
76
- Enable FSDP sharding for DiT model
77
- use_usp (`bool`, *optional*, defaults to False):
78
- Enable distribution strategy of USP.
79
- t5_cpu (`bool`, *optional*, defaults to False):
80
- Whether to place T5 model on CPU. Only works without t5_fsdp.
81
- """
82
  self.device = torch.device(f"cuda")
83
  self.config = config
84
  self.rank = rank
85
- self.t5_cpu = t5_cpu
86
  self.dtype = dtype
87
  self.num_train_timesteps = config.num_train_timesteps
88
  self.param_dtype = config.param_dtype
@@ -419,9 +393,9 @@ class WanT2V:
419
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
420
  else:
421
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
422
- arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
423
- arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
424
- arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
425
 
426
  if target_camera != None:
427
  recam_dict = {'cam_emb': cam_emb}
@@ -438,7 +412,7 @@ class WanT2V:
438
  if self.model.enable_teacache:
439
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
440
  if callback != None:
441
- callback(-1, True)
442
  for i, t in enumerate(tqdm(timesteps)):
443
  if target_camera != None:
444
  latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
@@ -494,7 +468,7 @@ class WanT2V:
494
  del temp_x0
495
 
496
  if callback is not None:
497
- callback(i, False)
498
 
499
  x0 = latents
500
 
 
49
  config,
50
  checkpoint_dir,
51
  rank=0,
 
 
 
 
52
  model_filename = None,
53
  text_encoder_filename = None,
54
  quantizeTransformer = False,
55
  dtype = torch.bfloat16
56
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  self.device = torch.device(f"cuda")
58
  self.config = config
59
  self.rank = rank
 
60
  self.dtype = dtype
61
  self.num_train_timesteps = config.num_train_timesteps
62
  self.param_dtype = config.param_dtype
 
393
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
394
  else:
395
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
396
+ arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
397
+ arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
398
+ arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
399
 
400
  if target_camera != None:
401
  recam_dict = {'cam_emb': cam_emb}
 
412
  if self.model.enable_teacache:
413
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
414
  if callback != None:
415
+ callback(-1, None, True)
416
  for i, t in enumerate(tqdm(timesteps)):
417
  if target_camera != None:
418
  latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
 
468
  del temp_x0
469
 
470
  if callback is not None:
471
+ callback(i, latents[0], False)
472
 
473
  x0 = latents
474
 
wgp.py CHANGED
The diff for this file is too large to render. See raw diff