Commit 453823e · 1 Parent(s): 068831e · committed by DeepBeepMeep

Added support for the FantasySpeaking model
README.md CHANGED
@@ -10,6 +10,7 @@
 
 
 ## 🔥 Latest News!!
+ * May 5 2025: 👋 Wan 2.1GP v4.5: FantasySpeaking model support: you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos (see recommended settings), plus new high-quality processing features (mixed 16/32-bit calculation and a 32-bit VAE).
 * April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model to transfer people or objects into a video; it works quite well at 720p and with more than 30 steps.
 * April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high-quality "infinite length videos" (see the Sliding Window section below). Note that SkyReels uses causal attention, which is only supported by Sdpa attention, so even if you choose another type of attention some of the processes will still use Sdpa attention.
 
@@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
 
 There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
 
- It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
+ It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration.
+
+ Other recommended settings for Vace:
+ - Use a long prompt description, especially for the people / objects that are in the background and not in the reference images. This will ensure consistency between the windows.
+ - Set a medium-size overlap window: long enough to give the model a sense of the motion, but short enough that any overlapped blurred frames do not turn the rest of the video into a blurred video.
+ - Truncate at least the last 4 frames of each generated window, as the last frames produced by Vace tend to be blurry.
+
 
 ### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
 With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).
fantasytalking/infer.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright Alibaba Inc. All Rights Reserved.
+
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
+
+ from .model import FantasyTalkingAudioConditionModel
+ from .utils import get_audio_features
+
+
+ def parse_audio(audio_path, num_frames, fps=23, device="cuda"):
+     fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
+     from mmgp import offload
+     from accelerate import init_empty_weights
+     from fantasytalking.model import AudioProjModel
+     with init_empty_weights():
+         proj_model = AudioProjModel(768, 2048)
+     offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
+     proj_model.to(device).eval().requires_grad_(False)
+
+     wav2vec_model_dir = "ckpts/wav2vec"
+     wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
+     wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
+     audio_wav2vec_fea = get_audio_features(wav2vec, wav2vec_processor, audio_path, fps, num_frames)
+
+     audio_proj_fea = proj_model(audio_wav2vec_fea)
+     pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
+     audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea, pos_idx_ranges, expand_length=4)  # [b, 21, 9+8, 768]
+     return audio_proj_split, audio_context_lens
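For context, a minimal usage sketch of `parse_audio` (not part of the commit; the audio file name is hypothetical, and it assumes `ckpts/fantasy_proj_model.safetensors` and `ckpts/wav2vec` have already been downloaded):

```python
# Hedged sketch: extract per-latent-frame audio features for an 81-frame clip.
from fantasytalking.infer import parse_audio

audio_proj_split, audio_context_lens = parse_audio(
    "talking_head.wav",  # hypothetical voice track
    num_frames=81,       # number of video frames to generate
    fps=23,
    device="cuda",
)
# audio_proj_split holds one padded slice of projected audio tokens per latent frame;
# audio_context_lens records each slice's unpadded length and is later passed as
# k_lens to pay_attention so that the padding is ignored.
print(audio_proj_split.shape, audio_context_lens.shape)
```

These two tensors are what `wan/image2video.py` now forwards to the model as `audio_proj` and `audio_context_lens`.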
fantasytalking/model.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from wan.modules.attention import pay_attention
+
+
+ class AudioProjModel(nn.Module):
+     def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
+         super().__init__()
+         self.cross_attention_dim = cross_attention_dim
+         self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
+         self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+     def forward(self, audio_embeds):
+         context_tokens = self.proj(audio_embeds)
+         context_tokens = self.norm(context_tokens)
+         return context_tokens  # [B, L, C]
+
+
+ class WanCrossAttentionProcessor(nn.Module):
+     def __init__(self, context_dim, hidden_dim):
+         super().__init__()
+
+         self.context_dim = context_dim
+         self.hidden_dim = hidden_dim
+
+         self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+         self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+
+         nn.init.zeros_(self.k_proj.weight)
+         nn.init.zeros_(self.v_proj.weight)
+
+     def __call__(
+         self,
+         q: torch.Tensor,
+         audio_proj: torch.Tensor,
+         latents_num_frames: int = 21,
+         audio_context_lens=None,
+     ) -> torch.Tensor:
+         """
+         audio_proj: [B, 21, L3, C]
+         audio_context_lens: [B*21].
+         """
+         b, l, n, d = q.shape
+
+         if len(audio_proj.shape) == 4:
+             audio_q = q.view(b * latents_num_frames, -1, n, d)  # [b*21, l1, n, d]
+             ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
+             ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
+             qkv_list = [audio_q, ip_key, ip_value]
+             del q, audio_q, ip_key, ip_value
+             audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)
+             audio_x = audio_x.view(b, l, n, d)
+             audio_x = audio_x.flatten(2)
+         elif len(audio_proj.shape) == 3:
+             ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
+             ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
+             qkv_list = [q, ip_key, ip_value]
+             del q, ip_key, ip_value
+             audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)
+             audio_x = audio_x.flatten(2)
+         return audio_x
+
+
+ class FantasyTalkingAudioConditionModel(nn.Module):
+     def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
+         super().__init__()
+
+         self.audio_in_dim = audio_in_dim
+         self.audio_proj_dim = audio_proj_dim
+
+     def split_audio_sequence(self, audio_proj_length, num_frames=81):
+         """
+         Map the audio feature sequence to corresponding latent frame slices.
+
+         Args:
+             audio_proj_length (int): The total length of the audio feature sequence
+                 (e.g., 173 in audio_proj[1, 173, 768]).
+             num_frames (int): The number of video frames in the training data (default: 81).
+
+         Returns:
+             list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
+                 (within the audio feature sequence) corresponding to a latent frame.
+         """
+         # Average number of tokens per original video frame
+         tokens_per_frame = audio_proj_length / num_frames
+
+         # Each latent frame covers 4 video frames, and we want the center
+         tokens_per_latent_frame = tokens_per_frame * 4
+         half_tokens = int(tokens_per_latent_frame / 2)
+
+         pos_indices = []
+         for i in range(int((num_frames - 1) / 4) + 1):
+             if i == 0:
+                 pos_indices.append(0)
+             else:
+                 start_token = tokens_per_frame * ((i - 1) * 4 + 1)
+                 end_token = tokens_per_frame * (i * 4 + 1)
+                 center_token = int((start_token + end_token) / 2) - 1
+                 pos_indices.append(center_token)
+
+         # Build index ranges centered around each position
+         pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
+
+         # Adjust the first range to avoid a negative start index
+         pos_idx_ranges[0] = [
+             -(half_tokens * 2 - pos_idx_ranges[1][0]),
+             pos_idx_ranges[1][0],
+         ]
+
+         return pos_idx_ranges
+
+     def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
+         """
+         Split the input tensor into subsequences based on index ranges, and apply right-side
+         zero-padding if the range exceeds the input boundaries.
+
+         Args:
+             input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
+             pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
+             expand_length (int): Number of tokens to expand on both sides of each subsequence.
+
+         Returns:
+             sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after
+                 padding. Each element is a padded subsequence.
+             k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of
+                 each subsequence. Useful for ignoring padding tokens in attention masks.
+         """
+         pos_idx_ranges = [
+             [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
+         ]
+         sub_sequences = []
+         seq_len = input_tensor.size(1)  # 173
+         max_valid_idx = seq_len - 1  # 172
+         k_lens_list = []
+         for start, end in pos_idx_ranges:
+             # Calculate the amount of padding needed on each side
+             pad_front = max(-start, 0)
+             pad_back = max(end - max_valid_idx, 0)
+
+             # Calculate the start and end indices of the valid part
+             valid_start = max(start, 0)
+             valid_end = min(end, max_valid_idx)
+
+             # Extract the valid part
+             if valid_start <= valid_end:
+                 valid_part = input_tensor[:, valid_start : valid_end + 1, :]
+             else:
+                 valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
+
+             # Pad along the sequence dimension (dim 1)
+             padded_subseq = F.pad(
+                 valid_part,
+                 (0, 0, 0, pad_back + pad_front, 0, 0),
+                 mode="constant",
+                 value=0,
+             )
+             k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
+
+             sub_sequences.append(padded_subseq)
+         return torch.stack(sub_sequences, dim=1), torch.tensor(
+             k_lens_list, dtype=torch.long
+         )
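To make the index mapping concrete, here is a small worked sketch using the numbers from the docstrings (173 audio tokens for 81 video frames); the constructor arguments mirror `infer.py`, and the zero tensor is only a stand-in for real audio features:

```python
import torch
from fantasytalking.model import FantasyTalkingAudioConditionModel

m = FantasyTalkingAudioConditionModel(None, audio_in_dim=768, audio_proj_dim=2048)

# 173 tokens / 81 frames ~ 2.14 tokens per video frame, so one latent frame
# (4 video frames) spans ~8.5 tokens and half_tokens = 4.
ranges = m.split_audio_sequence(173, num_frames=81)
print(len(ranges))   # 21 latent frames: (81 - 1) // 4 + 1
print(ranges[:2])    # [[-7, 1], [1, 9]] -- the first range is shifted to keep its width

# Split dummy audio features along those ranges, padding where a range leaves [0, 172].
audio_proj_fea = torch.zeros(1, 173, 768)
splits, k_lens = m.split_tensor_with_padding(audio_proj_fea, ranges, expand_length=4)
print(splits.shape)  # torch.Size([1, 21, 17, 768]): 9 base tokens + 4 of context per side
print(k_lens)        # actual (unpadded) token count of each slice
```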
fantasytalking/utils.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright Alibaba Inc. All Rights Reserved.
+
+ import imageio
+ import librosa
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def resize_image_by_longest_edge(image_path, target_size):
+     image = Image.open(image_path).convert("RGB")
+     width, height = image.size
+     scale = target_size / max(width, height)
+     new_size = (int(width * scale), int(height * scale))
+     return image.resize(new_size, Image.LANCZOS)
+
+
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+     writer = imageio.get_writer(
+         save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
+     )
+     for frame in tqdm(frames, desc="Saving video"):
+         frame = np.array(frame)
+         writer.append_data(frame)
+     writer.close()
+
+
+ def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
+     sr = 16000
+     audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # resample to 16 kHz
+
+     start_time = 0
+     # end_time = (0 + (num_frames - 1) * 1) / fps
+     end_time = num_frames / fps
+
+     start_sample = int(start_time * sr)
+     end_sample = int(end_time * sr)
+
+     try:
+         audio_segment = audio_input[start_sample:end_sample]
+     except Exception:
+         audio_segment = audio_input
+
+     input_values = audio_processor(
+         audio_segment, sampling_rate=sample_rate, return_tensors="pt"
+     ).input_values.to("cuda")
+
+     with torch.no_grad():
+         fea = wav2vec(input_values).last_hidden_state
+
+     return fea
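A brief usage sketch for the two media helpers (the paths are hypothetical and the frame list is a dummy):

```python
from fantasytalking.utils import resize_image_by_longest_edge, save_video

# Resize a reference portrait so its longest edge is 720 px, preserving the aspect ratio.
img = resize_image_by_longest_edge("portrait.png", target_size=720)

# Write a list of RGB frames (PIL images or HxWx3 uint8 arrays) to an mp4 at 23 fps.
save_video(frames=[img] * 23, save_path="preview.mp4", fps=23)
```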
requirements.txt CHANGED
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
- mmgp==3.4.1
+ mmgp==3.4.2
 peft==0.14.0
 mutagen
 pydantic==2.10.6
@@ -28,4 +28,5 @@ timm
 segment-anything
 omegaconf
 hydra-core
+ librosa
 # rembg==2.0.65
wan/configs/__init__.py CHANGED
@@ -44,6 +44,8 @@ SUPPORTED_SIZES = {
 VACE_SIZE_CONFIGS = {
     '480*832': (480, 832),
     '832*480': (832, 480),
+    '720*1280': (720, 1280),
+    '1280*720': (1280, 720),
 }
 
 VACE_MAX_AREA_CONFIGS = {
wan/diffusion_forcing.py CHANGED
@@ -56,16 +56,18 @@ class DTT2V:
56
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
57
  device=self.device)
58
 
59
- logging.info(f"Creating WanModel from {model_filename}")
60
  from mmgp import offload
61
  # model_filename = "model.safetensors"
62
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
 
63
  # offload.load_model_data(self.model, "recam.ckpt")
64
  # self.model.cpu()
65
- self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
 
66
  offload.change_dtype(self.model, dtype, True)
67
  # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
68
- # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
69
  # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
70
 
71
  self.model.eval().requires_grad_(False)
@@ -200,6 +202,9 @@ class DTT2V:
200
  fps: int = 24,
201
  VAE_tile_size = 0,
202
  joint_pass = False,
 
 
 
203
  callback = None,
204
  ):
205
  self._interrupt = False
@@ -211,6 +216,7 @@ class DTT2V:
211
 
212
  if ar_step == 0:
213
  causal_block_size = 1
 
214
 
215
  i2v_extra_kwrags = {}
216
  prefix_video = None
@@ -252,31 +258,33 @@ class DTT2V:
252
  prefix_video = output_video.to(self.device)
253
  else:
254
  causal_block_size = 1
 
255
  ar_step = 0
256
  prefix_video = image
257
  prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
258
  if prefix_video.dtype == torch.uint8:
259
  prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
260
  prefix_video = prefix_video.to(self.device)
261
- prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
262
- predix_video_latent_length = prefix_video[0].shape[1]
263
  truncate_len = predix_video_latent_length % causal_block_size
264
  if truncate_len != 0:
265
  if truncate_len == predix_video_latent_length:
266
  causal_block_size = 1
 
 
267
  else:
268
 print("the length of the prefix video is truncated for causal block size alignment.")
269
  predix_video_latent_length -= truncate_len
270
- prefix_video[0] = prefix_video[0][:, : predix_video_latent_length]
271
 
272
  base_num_frames_iter = latent_length
273
  latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
274
  latents = self.prepare_latents(
275
  latent_shape, dtype=torch.float32, device=self.device, generator=generator
276
  )
277
- latents = [latents]
278
  if prefix_video is not None:
279
- latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
280
  step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
281
  base_num_frames_iter,
282
  init_timesteps,
@@ -298,6 +306,8 @@ class DTT2V:
298
  if callback != None:
299
  callback(-1, None, True, override_num_inference_steps = updated_num_steps)
300
  if self.model.enable_teacache:
 
 
301
  time_steps_comb = []
302
  self.model.num_steps = updated_num_steps
303
  for i, timestep_i in enumerate(step_matrix):
@@ -309,7 +319,7 @@ class DTT2V:
309
  self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
310
  del time_steps_comb
311
  from mmgp import offload
312
- freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
313
  kwrags = {
314
  "freqs" :freqs,
315
  "fps" : fps_embeds,
@@ -320,27 +330,27 @@ class DTT2V:
320
  }
321
  kwrags.update(i2v_extra_kwrags)
322
 
323
-
324
  for i, timestep_i in enumerate(tqdm(step_matrix)):
 
 
325
  offload.set_step_no_for_lora(self.model, i)
326
  update_mask_i = step_update_mask[i]
327
  valid_interval_start, valid_interval_end = valid_interval[i]
328
  timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
329
- latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
330
  if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
331
  noise_factor = 0.001 * addnoise_condition
332
  timestep_for_noised_condition = addnoise_condition
333
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
334
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
335
  * (1.0 - noise_factor)
336
  + torch.randn_like(
337
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
338
  )
339
  * noise_factor
340
  )
341
  timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
342
  kwrags.update({
343
- "x" : torch.stack([latent_model_input[0]]),
344
  "t" : timestep,
345
  "current_step" : i,
346
  })
@@ -349,6 +359,7 @@ class DTT2V:
349
  if True:
350
  if not self.do_classifier_free_guidance:
351
  noise_pred = self.model(
 
352
  context=[prompt_embeds],
353
  **kwrags,
354
  )[0]
@@ -358,6 +369,7 @@ class DTT2V:
358
  else:
359
  if joint_pass:
360
  noise_pred_cond, noise_pred_uncond = self.model(
 
361
  context= [prompt_embeds, negative_prompt_embeds],
362
  **kwrags,
363
  )
@@ -365,12 +377,16 @@ class DTT2V:
365
  return None
366
  else:
367
  noise_pred_cond = self.model(
 
 
368
  context=[prompt_embeds],
369
  **kwrags,
370
  )[0]
371
  if self._interrupt:
372
  return None
373
  noise_pred_uncond = self.model(
 
 
374
  context=[negative_prompt_embeds],
375
  **kwrags,
376
  )[0]
@@ -380,18 +396,18 @@ class DTT2V:
380
  del noise_pred_cond, noise_pred_uncond
381
  for idx in range(valid_interval_start, valid_interval_end):
382
  if update_mask_i[idx].item():
383
- latents[0][:, idx] = sample_schedulers[idx].step(
384
  noise_pred[:, idx - valid_interval_start],
385
  timestep_i[idx],
386
- latents[0][:, idx],
387
  return_dict=False,
388
  generator=generator,
389
  )[0]
390
  sample_schedulers_counter[idx] += 1
391
  if callback is not None:
392
- callback(i, latents[0].squeeze(0), False)
393
 
394
- x0 = latents[0].unsqueeze(0)
395
  videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
396
  output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
397
  return output_video
 
56
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
57
  device=self.device)
58
 
59
+ logging.info(f"Creating WanModel from {model_filename[-1]}")
60
  from mmgp import offload
61
  # model_filename = "model.safetensors"
62
+ # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
63
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
64
  # offload.load_model_data(self.model, "recam.ckpt")
65
  # self.model.cpu()
66
+ # dtype = torch.float16
67
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
68
  offload.change_dtype(self.model, dtype, True)
69
  # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
70
+ # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json")
71
  # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
72
 
73
  self.model.eval().requires_grad_(False)
 
202
  fps: int = 24,
203
  VAE_tile_size = 0,
204
  joint_pass = False,
205
+ slg_layers = None,
206
+ slg_start = 0.0,
207
+ slg_end = 1.0,
208
  callback = None,
209
  ):
210
  self._interrupt = False
 
216
 
217
  if ar_step == 0:
218
  causal_block_size = 1
219
+ causal_attention = False
220
 
221
  i2v_extra_kwrags = {}
222
  prefix_video = None
 
258
  prefix_video = output_video.to(self.device)
259
  else:
260
  causal_block_size = 1
261
+ causal_attention = False
262
  ar_step = 0
263
  prefix_video = image
264
  prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
265
  if prefix_video.dtype == torch.uint8:
266
  prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
267
  prefix_video = prefix_video.to(self.device)
268
+ prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)]
269
+ predix_video_latent_length = prefix_video.shape[1]
270
  truncate_len = predix_video_latent_length % causal_block_size
271
  if truncate_len != 0:
272
  if truncate_len == predix_video_latent_length:
273
  causal_block_size = 1
274
+ causal_attention = False
275
+ ar_step = 0
276
  else:
277
 print("the length of the prefix video is truncated for causal block size alignment.")
278
  predix_video_latent_length -= truncate_len
279
+ prefix_video = prefix_video[:, : predix_video_latent_length]
280
 
281
  base_num_frames_iter = latent_length
282
  latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
283
  latents = self.prepare_latents(
284
  latent_shape, dtype=torch.float32, device=self.device, generator=generator
285
  )
 
286
  if prefix_video is not None:
287
+ latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
288
  step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
289
  base_num_frames_iter,
290
  init_timesteps,
 
306
  if callback != None:
307
  callback(-1, None, True, override_num_inference_steps = updated_num_steps)
308
  if self.model.enable_teacache:
309
+ x_count = 2 if self.do_classifier_free_guidance else 1
310
+ self.model.previous_residual = [None] * x_count
311
  time_steps_comb = []
312
  self.model.num_steps = updated_num_steps
313
  for i, timestep_i in enumerate(step_matrix):
 
319
  self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
320
  del time_steps_comb
321
  from mmgp import offload
322
+ freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
323
  kwrags = {
324
  "freqs" :freqs,
325
  "fps" : fps_embeds,
 
330
  }
331
  kwrags.update(i2v_extra_kwrags)
332
 
 
333
  for i, timestep_i in enumerate(tqdm(step_matrix)):
334
+ kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
335
+
336
  offload.set_step_no_for_lora(self.model, i)
337
  update_mask_i = step_update_mask[i]
338
  valid_interval_start, valid_interval_end = valid_interval[i]
339
  timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
340
+ latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
341
  if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
342
  noise_factor = 0.001 * addnoise_condition
343
  timestep_for_noised_condition = addnoise_condition
344
+ latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
345
+ latent_model_input[:, valid_interval_start:predix_video_latent_length]
346
  * (1.0 - noise_factor)
347
  + torch.randn_like(
348
+ latent_model_input[:, valid_interval_start:predix_video_latent_length]
349
  )
350
  * noise_factor
351
  )
352
  timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
353
  kwrags.update({
 
354
  "t" : timestep,
355
  "current_step" : i,
356
  })
 
359
  if True:
360
  if not self.do_classifier_free_guidance:
361
  noise_pred = self.model(
362
+ x=[latent_model_input],
363
  context=[prompt_embeds],
364
  **kwrags,
365
  )[0]
 
369
  else:
370
  if joint_pass:
371
  noise_pred_cond, noise_pred_uncond = self.model(
372
+ x=[latent_model_input, latent_model_input],
373
  context= [prompt_embeds, negative_prompt_embeds],
374
  **kwrags,
375
  )
 
377
  return None
378
  else:
379
  noise_pred_cond = self.model(
380
+ x=[latent_model_input],
381
+ x_id=0,
382
  context=[prompt_embeds],
383
  **kwrags,
384
  )[0]
385
  if self._interrupt:
386
  return None
387
  noise_pred_uncond = self.model(
388
+ x=[latent_model_input],
389
+ x_id=1,
390
  context=[negative_prompt_embeds],
391
  **kwrags,
392
  )[0]
 
396
  del noise_pred_cond, noise_pred_uncond
397
  for idx in range(valid_interval_start, valid_interval_end):
398
  if update_mask_i[idx].item():
399
+ latents[:, idx] = sample_schedulers[idx].step(
400
  noise_pred[:, idx - valid_interval_start],
401
  timestep_i[idx],
402
+ latents[:, idx],
403
  return_dict=False,
404
  generator=generator,
405
  )[0]
406
  sample_schedulers_counter[idx] += 1
407
  if callback is not None:
408
+ callback(i, latents.squeeze(0), False)
409
 
410
+ x0 = latents.unsqueeze(0)
411
  videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
412
  output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
413
  return output_video
wan/image2video.py CHANGED
@@ -8,7 +8,7 @@ import sys
8
  import types
9
  from contextlib import contextmanager
10
  from functools import partial
11
-
12
  import numpy as np
13
  import torch
14
  import torch.cuda.amp as amp
@@ -84,13 +84,29 @@ class WanI2V:
84
  config.clip_checkpoint),
85
  tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
86
 
87
- logging.info(f"Creating WanModel from {model_filename}")
88
  from mmgp import offload
89
 
90
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
91
- self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
92
- offload.change_dtype(self.model, dtype, True)
93
- # offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
94
 
95
  # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
96
  self.model.eval().requires_grad_(False)
@@ -102,6 +118,8 @@ class WanI2V:
102
  input_prompt,
103
  img,
104
  img2 = None,
 
 
105
  max_area=720 * 1280,
106
  frame_num=81,
107
  shift=5.0,
@@ -119,7 +137,11 @@ class WanI2V:
119
  slg_end = 1.0,
120
  cfg_star_switch = True,
121
  cfg_zero_step = 5,
122
- add_frames_for_end_image = True
 
 
 
 
123
  ):
124
  r"""
125
  Generates video frames from input image and text prompt using diffusion process.
@@ -167,17 +189,25 @@ class WanI2V:
167
  frame_num +=1
168
  lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
169
 
 
170
  h, w = img.shape[1:]
171
- aspect_ratio = h / w
  lat_h = round(
173
- np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
174
  self.patch_size[1] * self.patch_size[1])
175
  lat_w = round(
176
- np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
177
  self.patch_size[2] * self.patch_size[2])
178
  h = lat_h * self.vae_stride[1]
179
  w = lat_w * self.vae_stride[2]
180
-
181
  clip_image_size = self.clip.model.image_size
182
  img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
183
  img = resize_lanczos(img, clip_image_size, clip_image_size)
@@ -271,98 +301,101 @@ class WanI2V:
271
 
272
  # sample videos
273
  latent = noise
274
- batch_size = latent.shape[0]
275
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
276
 
277
- arg_c = {
278
- 'context': [context],
279
- 'clip_fea': clip_context,
280
- 'y': [y],
281
- 'freqs' : freqs,
282
- 'pipeline' : self,
283
- 'callback' : callback
284
- }
285
-
286
- arg_null = {
287
- 'context': [context_null],
288
- 'clip_fea': clip_context,
289
- 'y': [y],
290
- 'freqs' : freqs,
291
- 'pipeline' : self,
292
- 'callback' : callback
293
- }
294
-
295
- arg_both= {
296
- 'context': [context, context_null],
297
- 'clip_fea': clip_context,
298
- 'y': [y],
299
- 'freqs' : freqs,
300
- 'pipeline' : self,
301
- 'callback' : callback
302
- }
303
 
304
  if self.model.enable_teacache:
 
305
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
306
 
307
  # self.model.to(self.device)
308
  if callback != None:
309
  callback(-1, None, True)
310
-
311
  for i, t in enumerate(tqdm(timesteps)):
312
  offload.set_step_no_for_lora(self.model, i)
313
- slg_layers_local = None
314
- if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
315
- slg_layers_local = slg_layers
316
-
317
- latent_model_input = [latent.to(self.device)]
318
  timestep = [t]
319
 
320
  timestep = torch.stack(timestep).to(self.device)
  if joint_pass:
322
- noise_pred_cond, noise_pred_uncond = self.model(
323
- latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
  if self._interrupt:
325
  return None
326
  else:
327
  noise_pred_cond = self.model(
328
- latent_model_input,
329
- t=timestep,
330
- current_step=i,
331
- is_uncond=False,
332
- **arg_c,
333
  )[0]
334
  if self._interrupt:
335
- return None
336
  noise_pred_uncond = self.model(
337
- latent_model_input,
338
- t=timestep,
339
- current_step=i,
340
- is_uncond=True,
341
- slg_layers=slg_layers_local,
342
- **arg_null,
343
  )[0]
344
  if self._interrupt:
345
  return None
346
  del latent_model_input
347
 
348
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
349
- noise_pred_text = noise_pred_cond
350
  if cfg_star_switch:
351
- positive_flat = noise_pred_text.view(batch_size, -1)
352
  negative_flat = noise_pred_uncond.view(batch_size, -1)
353
 
354
  alpha = optimized_scale(positive_flat,negative_flat)
355
  alpha = alpha.view(batch_size, 1, 1, 1)
356
 
357
-
358
  if (i <= cfg_zero_step):
359
- noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
360
  else:
361
  noise_pred_uncond *= alpha
362
- noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
363
-
364
- del noise_pred_uncond
365
-
 
366
  temp_x0 = sample_scheduler.step(
367
  noise_pred.unsqueeze(0),
368
  t,
@@ -376,9 +409,6 @@ class WanI2V:
376
  if callback is not None:
377
  callback(i, latent, False)
378
 
379
-
380
- # x0 = [latent.to(self.device, dtype=self.dtype)]
381
-
382
  x0 = [latent]
383
 
384
  # x0 = [lat_y]
 
8
  import types
9
  from contextlib import contextmanager
10
  from functools import partial
11
+ import json
12
  import numpy as np
13
  import torch
14
  import torch.cuda.amp as amp
 
84
  config.clip_checkpoint),
85
  tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
86
 
87
+ logging.info(f"Creating WanModel from {model_filename[-1]}")
88
  from mmgp import offload
89
 
90
+ # fantasy = torch.load("c:/temp/fantasy.ckpt")
91
+ # proj_model = fantasy["proj_model"]
92
+ # audio_processor = fantasy["audio_processor"]
93
+ # offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
94
+ # offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
95
+ # for k,v in audio_processor.items():
96
+ # audio_processor[k] = v.to(torch.bfloat16)
97
+ # with open("fantasy_config.json", "r", encoding="utf-8") as reader:
98
+ # config_text = reader.read()
99
+ # config_json = json.loads(config_text)
100
+ # offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
101
+ # model_filename = [model_filename, "audio_processor_bf16.safetensors"]
102
+ # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
103
+ # dtype = torch.float16
104
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
105
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
106
+ # offload.change_dtype(self.model, dtype, True)
107
+ # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
108
+ # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
109
+ # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
110
 
111
  # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
112
  self.model.eval().requires_grad_(False)
 
118
  input_prompt,
119
  img,
120
  img2 = None,
121
+ height =720,
122
+ width = 1280,
123
  max_area=720 * 1280,
124
  frame_num=81,
125
  shift=5.0,
 
137
  slg_end = 1.0,
138
  cfg_star_switch = True,
139
  cfg_zero_step = 5,
140
+ add_frames_for_end_image = True,
141
+ audio_scale=None,
142
+ audio_cfg_scale=None,
143
+ audio_proj=None,
144
+ audio_context_lens=None,
145
  ):
146
  r"""
147
  Generates video frames from input image and text prompt using diffusion process.
 
189
  frame_num +=1
190
  lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
191
 
192
+
193
  h, w = img.shape[1:]
194
+ # aspect_ratio = h / w
195
+
196
+ scale1 = min(height / h, width / w)
197
+ scale2 = min(height / h, width / w)
198
+ scale = max(scale1, scale2)
199
+ new_height = int(h * scale)
200
+ new_width = int(w * scale)
201
+
202
  lat_h = round(
203
+ new_height // self.vae_stride[1] //
204
  self.patch_size[1] * self.patch_size[1])
205
  lat_w = round(
206
+ new_width // self.vae_stride[2] //
207
  self.patch_size[2] * self.patch_size[2])
208
  h = lat_h * self.vae_stride[1]
209
  w = lat_w * self.vae_stride[2]
210
+
211
  clip_image_size = self.clip.model.image_size
212
  img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
213
  img = resize_lanczos(img, clip_image_size, clip_image_size)
 
301
 
302
  # sample videos
303
  latent = noise
304
+ batch_size = 1
305
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
306
 
307
+ kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
308
+
309
+ if audio_proj != None:
310
+ kwargs.update({
311
+ "audio_proj": audio_proj.to(self.dtype),
312
+ "audio_context_lens": audio_context_lens,
313
+ })
314
 
315
  if self.model.enable_teacache:
316
+ self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
317
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
318
 
319
  # self.model.to(self.device)
320
  if callback != None:
321
  callback(-1, None, True)
322
+ latent = latent.to(self.device)
323
  for i, t in enumerate(tqdm(timesteps)):
324
  offload.set_step_no_for_lora(self.model, i)
325
+ kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
326
+ latent_model_input = latent
 
 
 
327
  timestep = [t]
328
 
329
  timestep = torch.stack(timestep).to(self.device)
330
+ kwargs.update({
331
+ 't' :timestep,
332
+ 'current_step' :i,
333
+ })
334
+
335
+
336
  if joint_pass:
337
+ if audio_proj == None:
338
+ noise_pred_cond, noise_pred_uncond = self.model(
339
+ [latent_model_input, latent_model_input],
340
+ context=[context, context_null],
341
+ **kwargs)
342
+ else:
343
+ noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
344
+ [latent_model_input, latent_model_input, latent_model_input],
345
+ context=[context, context, context_null],
346
+ audio_scale = [audio_scale, None, None ],
347
+ **kwargs)
348
+
349
  if self._interrupt:
350
  return None
351
  else:
352
  noise_pred_cond = self.model(
353
+ [latent_model_input],
354
+ context=[context],
355
+ audio_scale = None if audio_scale == None else [audio_scale],
356
+ x_id=0,
357
+ **kwargs,
358
  )[0]
359
  if self._interrupt:
360
+ return None
361
+
362
+ if audio_proj != None:
363
+ noise_pred_noaudio = self.model(
364
+ [latent_model_input],
365
+ x_id=1,
366
+ context=[context],
367
+ **kwargs,
368
+ )[0]
369
+ if self._interrupt:
370
+ return None
371
+
372
  noise_pred_uncond = self.model(
373
+ [latent_model_input],
374
+ x_id=1 if audio_scale == None else 2,
375
+ context=[context_null],
376
+ **kwargs,
 
 
377
  )[0]
378
  if self._interrupt:
379
  return None
380
  del latent_model_input
381
 
382
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
 
383
  if cfg_star_switch:
384
+ positive_flat = noise_pred_cond.view(batch_size, -1)
385
  negative_flat = noise_pred_uncond.view(batch_size, -1)
386
 
387
  alpha = optimized_scale(positive_flat,negative_flat)
388
  alpha = alpha.view(batch_size, 1, 1, 1)
389
 
 
390
  if (i <= cfg_zero_step):
391
+ noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
392
  else:
393
  noise_pred_uncond *= alpha
394
+ if audio_scale == None:
395
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
396
+ else:
397
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
398
+ noise_pred_uncond, noise_pred_noaudio = None, None
399
  temp_x0 = sample_scheduler.step(
400
  noise_pred.unsqueeze(0),
401
  t,
 
409
  if callback is not None:
410
  callback(i, latent, False)
411
 
 
 
 
412
  x0 = [latent]
413
 
414
  # x0 = [lat_y]
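With FantasySpeaking enabled, the image-to-video sampler above runs three prediction branches (text + audio, text only, unconditional) and blends them with two guidance scales. A minimal sketch of that blend, with dummy tensors standing in for the model outputs and example scale values that are not taken from the commit:

```python
import torch

guide_scale, audio_cfg_scale = 5.0, 5.0  # example values only
cond, noaudio, uncond = (torch.randn(16, 21, 90, 160) for _ in range(3))  # dummy predictions

# Without audio conditioning, the usual CFG update is used:
noise_pred = uncond + guide_scale * (cond - uncond)

# With audio: text guidance is applied to the no-audio branch, then an extra term pushes
# from "no audio" toward "with audio", weighted by audio_cfg_scale (as in the code above).
noise_pred = uncond + guide_scale * (noaudio - uncond) + audio_cfg_scale * (cond - noaudio)
```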
wan/modules/attention.py CHANGED
@@ -57,15 +57,15 @@ def sageattn_wrapper(
57
  ):
58
  q,k, v = qkv_list
59
  padding_length = q.shape[0] -attention_length
60
- q = q[:attention_length, :, : ].unsqueeze(0)
61
- k = k[:attention_length, :, : ].unsqueeze(0)
62
- v = v[:attention_length, :, : ].unsqueeze(0)
63
  if True:
64
  qkv_list = [q,k,v]
65
  del q, k ,v
66
- o = alt_sageattn(qkv_list, tensor_layout="NHD").squeeze(0)
67
  else:
68
- o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
69
  del q, k ,v
70
 
71
  qkv_list.clear()
@@ -107,14 +107,14 @@ def sdpa_wrapper(
107
  attention_length
108
  ):
109
  q,k, v = qkv_list
110
- padding_length = q.shape[0] -attention_length
111
- q = q[:attention_length, :].transpose(0,1).unsqueeze(0)
112
- k = k[:attention_length, :].transpose(0,1).unsqueeze(0)
113
- v = v[:attention_length, :].transpose(0,1).unsqueeze(0)
114
 
115
  o = F.scaled_dot_product_attention(
116
  q, k, v, attn_mask=None, is_causal=False
117
- ).squeeze(0).transpose(0,1)
118
  del q, k ,v
119
  qkv_list.clear()
120
 
@@ -159,36 +159,72 @@ def pay_attention(
159
  deterministic=False,
160
  version=None,
161
  force_attention= None,
162
- cross_attn= False
 
163
  ):
164
 
165
  attn = offload.shared_state["_attention"] if force_attention== None else force_attention
166
  q,k,v = qkv_list
167
  qkv_list.clear()
168
 
169
-
170
  # params
171
  b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
172
- assert b==1
173
- q = q.squeeze(0)
174
- k = k.squeeze(0)
175
- v = v.squeeze(0)
176
-
177
 
178
  q = q.to(v.dtype)
179
  k = k.to(v.dtype)
180
-
181
- # if q_scale is not None:
182
- # q = q * q_scale
183
-
 
184
  if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
185
  warnings.warn(
186
  'Flash attention 3 is not available, use flash attention 2 instead.'
187
  )
188
 
189
  if attn=="sage" or attn=="flash":
190
- cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
191
- cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
192
 
193
  # apply attention
194
  if attn=="sage":
@@ -207,7 +243,7 @@ def pay_attention(
207
  qkv_list = [q,k,v]
208
  del q,k,v
209
 
210
- x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
211
  # else:
212
  # layer = offload.shared_state["layer"]
213
  # embed_sizes = offload.shared_state["embed_sizes"]
@@ -267,8 +303,8 @@ def pay_attention(
267
 
268
  elif attn=="sdpa":
269
  qkv_list = [q, k, v]
270
- del q, k , v
271
- x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
272
  elif attn=="flash" and version == 3:
273
  # Note: dropout_p, window_size are not supported in FA3 now.
274
  x = flash_attn_interface.flash_attn_varlen_func(
@@ -302,59 +338,11 @@ def pay_attention(
302
  # output
303
 
304
  elif attn=="xformers":
305
- x = memory_efficient_attention(
306
- q.unsqueeze(0),
307
- k.unsqueeze(0),
308
- v.unsqueeze(0),
309
- ) #.unsqueeze(0)
 
310
 
311
- return x.type(out_dtype)
312
-
313
-
314
- def attention(
315
- q,
316
- k,
317
- v,
318
- q_lens=None,
319
- k_lens=None,
320
- dropout_p=0.,
321
- softmax_scale=None,
322
- q_scale=None,
323
- causal=False,
324
- window_size=(-1, -1),
325
- deterministic=False,
326
- dtype=torch.bfloat16,
327
- fa_version=None,
328
- ):
329
- if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
330
- return pay_attention(
331
- q=q,
332
- k=k,
333
- v=v,
334
- q_lens=q_lens,
335
- k_lens=k_lens,
336
- dropout_p=dropout_p,
337
- softmax_scale=softmax_scale,
338
- q_scale=q_scale,
339
- causal=causal,
340
- window_size=window_size,
341
- deterministic=deterministic,
342
- dtype=dtype,
343
- version=fa_version,
344
- )
345
- else:
346
- if q_lens is not None or k_lens is not None:
347
- warnings.warn(
348
- 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
349
- )
350
- attn_mask = None
351
-
352
- q = q.transpose(1, 2).to(dtype)
353
- k = k.transpose(1, 2).to(dtype)
354
- v = v.transpose(1, 2).to(dtype)
355
-
356
- out = torch.nn.functional.scaled_dot_product_attention(
357
- q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
358
-
359
- out = out.transpose(1, 2).contiguous()
360
- return out
 
57
  ):
58
  q,k, v = qkv_list
59
  padding_length = q.shape[0] -attention_length
60
+ q = q[:attention_length, :, : ]
61
+ k = k[:attention_length, :, : ]
62
+ v = v[:attention_length, :, : ]
63
  if True:
64
  qkv_list = [q,k,v]
65
  del q, k ,v
66
+ o = alt_sageattn(qkv_list, tensor_layout="NHD")
67
  else:
68
+ o = sageattn(q, k, v, tensor_layout="NHD")
69
  del q, k ,v
70
 
71
  qkv_list.clear()
 
107
  attention_length
108
  ):
109
  q,k, v = qkv_list
110
+ padding_length = q.shape[1] -attention_length
111
+ q = q[:attention_length, :].transpose(1,2)
112
+ k = k[:attention_length, :].transpose(1,2)
113
+ v = v[:attention_length, :].transpose(1,2)
114
 
115
  o = F.scaled_dot_product_attention(
116
  q, k, v, attn_mask=None, is_causal=False
117
+ ).transpose(1,2)
118
  del q, k ,v
119
  qkv_list.clear()
120
 
 
159
  deterministic=False,
160
  version=None,
161
  force_attention= None,
162
+ cross_attn= False,
163
+ k_lens = None
164
  ):
165
 
166
  attn = offload.shared_state["_attention"] if force_attention== None else force_attention
167
  q,k,v = qkv_list
168
  qkv_list.clear()
169
 
 
170
  # params
171
  b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
 
 
 
 
 
172
 
173
  q = q.to(v.dtype)
174
  k = k.to(v.dtype)
175
+ if b > 0 and k_lens != None and attn in ("sage2", "sdpa"):
176
+ # Poor man's variable-length attention: group batch items that share the same key length
177
+ chunk_sizes = []
178
+ k_sizes = []
179
+ current_size = k_lens[0]
180
+ current_count= 1
181
+ for k_len in k_lens[1:]:
182
+ if k_len == current_size:
183
+ current_count += 1
184
+ else:
185
+ chunk_sizes.append(current_count)
186
+ k_sizes.append(current_size)
187
+ current_count = 1
188
+ current_size = k_len
189
+ chunk_sizes.append(current_count)
190
+ k_sizes.append(k_len)
191
+ if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]:
192
+ q_chunks =torch.split(q, chunk_sizes)
193
+ k_chunks =torch.split(k, chunk_sizes)
194
+ v_chunks =torch.split(v, chunk_sizes)
195
+ q, k, v = None, None, None
196
+ k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)]
197
+ v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)]
198
+ o = []
199
+ for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks):
200
+ qkv_list = [sub_q, sub_k, sub_v]
201
+ sub_q, sub_k, sub_v = None, None, None
202
+ o.append( pay_attention(qkv_list) )
203
+ q_chunks, k_chunks, v_chunks = None, None, None
204
+ o = torch.cat(o, dim = 0)
205
+ return o
206
  if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
207
  warnings.warn(
208
  'Flash attention 3 is not available, use flash attention 2 instead.'
209
  )
210
 
211
  if attn=="sage" or attn=="flash":
212
+ if b != 1 :
213
+ if k_lens == None:
214
+ k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
215
+ k = torch.cat([u[:v] for u, v in zip(k, k_lens)])
216
+ v = torch.cat([u[:v] for u, v in zip(v, k_lens)])
217
+ q = q.reshape(-1, *q.shape[-2:])
218
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
219
+ cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
220
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
221
+ else:
222
+ cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
223
+ cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
224
+ q = q.squeeze(0)
225
+ k = k.squeeze(0)
226
+ v = v.squeeze(0)
227
+
228
 
229
  # apply attention
230
  if attn=="sage":
 
243
  qkv_list = [q,k,v]
244
  del q,k,v
245
 
246
+ x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
247
  # else:
248
  # layer = offload.shared_state["layer"]
249
  # embed_sizes = offload.shared_state["embed_sizes"]
 
303
 
304
  elif attn=="sdpa":
305
  qkv_list = [q, k, v]
306
+ del q ,k ,v
307
+ x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0)
308
  elif attn=="flash" and version == 3:
309
  # Note: dropout_p, window_size are not supported in FA3 now.
310
  x = flash_attn_interface.flash_attn_varlen_func(
 
338
  # output
339
 
340
  elif attn=="xformers":
341
+ from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
342
+ if b != 1 and k_lens != None:
343
+ attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) )
344
+ x = memory_efficient_attention(q, k, v, attn_bias= attn_mask )
345
+ else:
346
+ x = memory_efficient_attention(q, k, v )
347
 
348
+ return x.type(out_dtype)
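The `k_lens` path added to `pay_attention` implements a simple variable-length scheme: consecutive batch items that share the same key length are grouped, each group attends only to its first `k_len` keys, and the per-group results are concatenated. A standalone sketch of just the grouping step, with example lengths:

```python
import torch

k_lens = [9, 9, 17, 17, 17]  # per-item valid key lengths (example values)
chunk_sizes, k_sizes = [], []
current_size, current_count = k_lens[0], 1
for k_len in k_lens[1:]:
    if k_len == current_size:
        current_count += 1
    else:
        chunk_sizes.append(current_count)
        k_sizes.append(current_size)
        current_size, current_count = k_len, 1
chunk_sizes.append(current_count)
k_sizes.append(current_size)
print(chunk_sizes, k_sizes)  # [2, 3] [9, 17]

# Queries/keys/values are then split into those groups; each group's keys are truncated
# to k_sizes[i] before a dense attention call, mirroring the recursion in pay_attention.
q = torch.randn(5, 32, 12, 64)
q_chunks = torch.split(q, chunk_sizes)
print([c.shape[0] for c in q_chunks])  # [2, 3]
```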
wan/modules/model.py CHANGED
@@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module):
197
  del q,k
198
 
199
  q,k = apply_rotary_emb(qklist, freqs, head_first=False)
200
- qkv_list = [q,k,v]
201
- del q,k,v
202
  if block_mask == None:
 
 
203
  x = pay_attention(
204
  qkv_list,
205
  window_size=self.window_size)
@@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module):
212
  .transpose(1, 2)
213
  .contiguous()
214
  )
 
215
 
216
  # if not self._flag_ar_attention:
217
  # q = rope_apply(q, grid_sizes, freqs)
@@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module):
241
 
242
  class WanT2VCrossAttention(WanSelfAttention):
243
 
244
- def forward(self, xlist, context):
245
  r"""
246
  Args:
247
  x(Tensor): Shape [B, L1, C]
@@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention):
262
  v = self.v(context).view(b, -1, n, d)
263
 
264
  # compute attention
 
265
  qvl_list=[q, k, v]
266
  del q, k, v
267
  x = pay_attention(qvl_list, cross_attn= True)
@@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention):
287
  # self.alpha = nn.Parameter(torch.zeros((1, )))
288
  self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
289
 
290
- def forward(self, xlist, context):
291
  r"""
292
  Args:
293
  x(Tensor): Shape [B, L1, C]
@@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention):
310
  del x
311
  self.norm_q(q)
312
  q= q.view(b, -1, n, d)
 
 
313
  k = self.k(context)
314
  self.norm_k(k)
315
  k = k.view(b, -1, n, d)
@@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention):
334
  img_x = img_x.flatten(2)
335
  x += img_x
336
  del img_x
 
 
337
  x = self.o(x)
338
  return x
339
 
@@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module):
398
  hints= None,
399
  context_scale=1.0,
400
  cam_emb= None,
401
- block_mask = None
 
 
 
402
  ):
403
  r"""
404
  Args:
@@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module):
433
  if cam_emb != None:
434
  cam_emb = self.cam_encoder(cam_emb)
435
  cam_emb = cam_emb.repeat(1, 2, 1)
436
- cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1)
437
  cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
438
  x_mod += cam_emb
439
 
@@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module):
453
  y = y.to(attention_dtype)
454
  ylist= [y]
455
  del y
456
- x += self.cross_attn(ylist, context).to(dtype)
457
 
458
  y = self.norm2(x)
459
 
@@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin):
610
  eps=1e-6,
611
  recammaster = False,
612
  inject_sample_info = False,
 
613
  ):
614
  r"""
615
  Initialize the diffusion model backbone.
@@ -742,43 +752,48 @@ class WanModel(ModelMixin, ConfigMixin):
742
  block.projector.weight = nn.Parameter(torch.eye(dim))
743
  block.projector.bias = nn.Parameter(torch.zeros(dim))
744
 
 
 
746
- def lock_layers_dtypes(self, dtype = torch.float32, force = False):
747
- count = 0
748
- layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
749
- self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
750
  if hasattr(self, "fps_embedding"):
751
- layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
752
 
753
  if hasattr(self, "vace_patch_embedding"):
754
- layer_list += [self.vace_patch_embedding]
755
- layer_list += [self.vace_blocks[0].before_proj]
756
  for block in self.vace_blocks:
757
- layer_list += [block.after_proj, block.norm3]
 
 
758
 
759
  # cam master
760
  if hasattr(self.blocks[0], "projector"):
761
  for block in self.blocks:
762
- layer_list += [block.projector]
763
 
764
- for block in self.blocks:
765
- layer_list += [block.norm3]
766
- for layer in layer_list:
767
- if hasattr(layer, "weight"):
768
- if layer.weight.dtype == dtype :
769
- count += 1
770
- elif force:
771
- if hasattr(layer, "weight"):
772
- layer.weight.data = layer.weight.data.to(dtype)
773
- if hasattr(layer, "bias"):
774
- layer.bias.data = layer.bias.data.to(dtype)
775
- count += 1
776
-
777
- layer._lock_dtype = dtype
778
 
 
 
 
 
779
 
780
- if count > 0:
781
- self._lock_dtype = dtype
782
 
783
 
784
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
@@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin):
788
  t = torch.stack([t])
789
  time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
790
  e_list.append(time_emb)
791
-
792
  best_threshold = 0.01
793
  best_diff = 1000
794
  best_signed_diff = 1000
@@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin):
798
  accumulated_rel_l1_distance =0
799
  nb_steps = 0
800
  diff = 1000
 
801
  for i, t in enumerate(timesteps):
802
  skip = False
803
- if not (i<=start_step or i== len(timesteps)):
804
- accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
 
 
805
  if accumulated_rel_l1_distance < threshold:
806
  skip = True
 
807
  else:
808
  accumulated_rel_l1_distance = 0
809
  if not skip:
@@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin):
812
  diff = abs(signed_diff)
813
  if diff < best_diff:
814
  best_threshold = threshold
 
815
  best_diff = diff
816
  best_signed_diff = signed_diff
817
  elif diff > best_diff:
@@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin):
819
  threshold += 0.01
820
  self.rel_l1_thresh = best_threshold
821
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
 
822
  return best_threshold
823
 
824
 
@@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin):
834
  freqs = None,
835
  pipeline = None,
836
  current_step = 0,
837
- is_uncond=False,
838
  max_steps = 0,
839
  slg_layers=None,
840
  callback = None,
@@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin):
842
  fps = None,
843
  causal_block_size = 1,
844
  causal_attention = False,
845
- x_neg = None
 
 
 
846
  ):
847
- # dtype = self.blocks[0].self_attn.q.weight.dtype
848
- dtype = self.patch_embedding.weight.dtype
849
 
850
  if self.model_type == 'i2v':
851
  assert clip_fea is not None and y is not None
@@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin):
854
  if torch.is_tensor(freqs) and freqs.device != device:
855
  freqs = freqs.to(device)
856
 
857
- if y is not None:
858
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
859
 
860
- # embeddings
861
- x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
862
- if x_neg !=None:
863
- x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]
864
-
865
- grid_sizes = [ list(u.shape[2:]) for u in x]
866
- embed_sizes = grid_sizes[0]
867
- if causal_attention : #causal_block_size > 0:
868
- frame_num = embed_sizes[0]
869
- height = embed_sizes[1]
870
- width = embed_sizes[2]
871
  block_num = frame_num // causal_block_size
872
  range_tensor = torch.arange(block_num).view(-1, 1)
873
  range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
@@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin):
878
  block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
879
  del causal_mask
880
 
881
- offload.shared_state["embed_sizes"] = embed_sizes
882
  offload.shared_state["step_no"] = current_step
883
  offload.shared_state["max_steps"] = max_steps
884
 
885
- x = [u.flatten(2).transpose(1, 2) for u in x]
886
- x = x[0]
887
- if x_neg !=None:
888
- x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
889
- x_neg = x_neg[0]
890
 
891
- if t.dim() == 2:
892
- b, f = t.shape
893
- _flag_df = True
894
- else:
895
- _flag_df = False
896
  e = self.time_embedding(
897
- sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
898
  ) # b, dim
899
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
900
 
901
  if self.inject_sample_info:
902
  fps = torch.tensor(fps, dtype=torch.long, device=device)
903
 
904
- fps_emb = self.fps_embedding(fps).to(dtype) # float()
905
  if _flag_df:
906
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
907
  else:
@@ -913,30 +940,28 @@ class WanModel(ModelMixin, ConfigMixin):
913
  if clip_fea is not None:
914
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
915
  context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
916
-
917
- joint_pass = len(context) > 0
918
- x_list = [x]
919
- if joint_pass:
920
- if x_neg == None:
921
- x_list += [x.clone() for i in range(len(context) - 1) ]
922
- else:
923
- x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
924
- is_uncond = False
925
- del x
926
  context_list = context
 
 
 
 
927
 
928
  # arguments
929
 
930
  kwargs = dict(
931
  grid_sizes=grid_sizes,
932
  freqs=freqs,
933
- cam_emb = cam_emb
 
 
 
934
  )
935
 
936
  if vace_context == None:
937
  hints_list = [None ] *len(x_list)
938
  else:
939
- # embeddings
940
  c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
941
  c = [u.flatten(2).transpose(1, 2) for u in c]
942
  c = c[0]
@@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin):
947
 
948
  should_calc = True
949
  if self.enable_teacache:
950
- if is_uncond:
951
  should_calc = self.should_calc
952
  else:
953
  if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
@@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin):
955
  self.accumulated_rel_l1_distance = 0
956
  else:
957
  rescale_func = np.poly1d(self.coefficients)
958
- self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
 
959
  if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
960
  should_calc = False
961
  self.teacache_skipped_steps += 1
962
- # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
963
  else:
964
  should_calc = True
965
  self.accumulated_rel_l1_distance = 0
@@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin):
967
  self.should_calc = should_calc
968
 
969
  if not should_calc:
970
- for i, x in enumerate(x_list):
971
- x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
 
972
  else:
973
  if self.enable_teacache:
974
- if joint_pass or is_uncond:
975
- self.previous_residual_uncond = None
976
- if joint_pass or not is_uncond:
977
- self.previous_residual_cond = None
978
- ori_hidden_states = x_list[0].clone()
 
 
 
979
 
980
  for block_idx, block in enumerate(self.blocks):
981
  offload.shared_state["layer"] = block_idx
@@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin):
984
  if pipeline._interrupt:
985
  return [None] * len(x_list)
986
 
987
- if slg_layers is not None and block_idx in slg_layers:
988
- if is_uncond and not joint_pass:
989
  continue
990
  x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
991
-
992
  else:
993
- for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
994
- x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
995
  del x
996
  del context, hints
997
 
998
  if self.enable_teacache:
999
  if joint_pass:
1000
- self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
1001
- self.previous_residual_uncond = ori_hidden_states
1002
- torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
 
 
 
 
 
1003
  else:
1004
- residual = ori_hidden_states # just to have a readable code
1005
- torch.sub(x_list[0], ori_hidden_states, out=residual)
1006
- if i==1 or is_uncond:
1007
- self.previous_residual_uncond = residual
1008
- else:
1009
- self.previous_residual_cond = residual
1010
  residual, ori_hidden_states = None, None
1011
 
1012
  for i, x in enumerate(x_list):
@@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin):
1037
 
1038
  c = self.out_dim
1039
  out = []
1040
- for u, v in zip(x, grid_sizes):
1041
- u = u[:math.prod(v)].view(*v, *self.patch_size, c)
1042
  u = torch.einsum('fhwpqrc->cfphqwr', u)
1043
- u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1044
  out.append(u)
1045
  return out
1046
 
 
197
  del q,k
198
 
199
  q,k = apply_rotary_emb(qklist, freqs, head_first=False)
 
 
200
  if block_mask == None:
201
+ qkv_list = [q,k,v]
202
+ del q,k,v
203
  x = pay_attention(
204
  qkv_list,
205
  window_size=self.window_size)
 
212
  .transpose(1, 2)
213
  .contiguous()
214
  )
215
+ del q,k,v
216
 
217
  # if not self._flag_ar_attention:
218
  # q = rope_apply(q, grid_sizes, freqs)
 
242
 
243
  class WanT2VCrossAttention(WanSelfAttention):
244
 
245
+ def forward(self, xlist, context, grid_sizes, *args, **kwargs):
246
  r"""
247
  Args:
248
  x(Tensor): Shape [B, L1, C]
 
263
  v = self.v(context).view(b, -1, n, d)
264
 
265
  # compute attention
266
+ v = v.contiguous().clone()
267
  qvl_list=[q, k, v]
268
  del q, k, v
269
  x = pay_attention(qvl_list, cross_attn= True)
 
289
  # self.alpha = nn.Parameter(torch.zeros((1, )))
290
  self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
291
 
292
+ def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
293
  r"""
294
  Args:
295
  x(Tensor): Shape [B, L1, C]
 
312
  del x
313
  self.norm_q(q)
314
  q= q.view(b, -1, n, d)
315
+ if audio_scale != None:
316
+ audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
317
  k = self.k(context)
318
  self.norm_k(k)
319
  k = k.view(b, -1, n, d)
 
338
  img_x = img_x.flatten(2)
339
  x += img_x
340
  del img_x
341
+ if audio_scale != None:
342
+ x.add_(audio_x, alpha= audio_scale)
343
  x = self.o(x)
344
  return x
345
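Note on the hunk above: the FantasySpeaking audio path computes a separate attention result (`audio_x`) through the block's processor and folds it into the regular cross-attention output with one scaled in-place add. A minimal sketch of that blending step, with illustrative tensor shapes (the helper name `blend_audio` is not part of the repo):

```python
import torch

def blend_audio(x, audio_x, audio_scale):
    """x: [B, L, C] text/image cross-attention output; audio_x: [B, L, C] audio attention."""
    if audio_scale is not None:
        # same effect as x.add_(audio_x, alpha=audio_scale): x += audio_scale * audio_x, in place
        x.add_(audio_x, alpha=audio_scale)
    return x

x = torch.randn(1, 16, 32)
audio_x = torch.randn(1, 16, 32)
out = blend_audio(x, audio_x, audio_scale=0.7)
```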
 
 
404
  hints= None,
405
  context_scale=1.0,
406
  cam_emb= None,
407
+ block_mask = None,
408
+ audio_proj= None,
409
+ audio_context_lens= None,
410
+ audio_scale=None,
411
  ):
412
  r"""
413
  Args:
 
442
  if cam_emb != None:
443
  cam_emb = self.cam_encoder(cam_emb)
444
  cam_emb = cam_emb.repeat(1, 2, 1)
445
+ cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1)
446
  cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
447
  x_mod += cam_emb
448
 
 
462
  y = y.to(attention_dtype)
463
  ylist= [y]
464
  del y
465
+ x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
466
 
467
  y = self.norm2(x)
468
 
 
619
  eps=1e-6,
620
  recammaster = False,
621
  inject_sample_info = False,
622
+ fantasytalking_dim = 0,
623
  ):
624
  r"""
625
  Initialize the diffusion model backbone.
 
752
  block.projector.weight = nn.Parameter(torch.eye(dim))
753
  block.projector.bias = nn.Parameter(torch.zeros(dim))
754
 
755
+ if fantasytalking_dim > 0:
756
+ from fantasytalking.model import WanCrossAttentionProcessor
757
+ for block in self.blocks:
758
+ block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
759
+
760
+
761
+ def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
762
+ layer_list = [self.head, self.head.head, self.patch_embedding]
763
+ target_dtype = dtype
764
+
765
+ layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2],
766
+ self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
767
+
768
+ for block in self.blocks:
769
+ layer_list2 += [block.norm3]
770
 
 
 
 
 
771
  if hasattr(self, "fps_embedding"):
772
+ layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
773
 
774
  if hasattr(self, "vace_patch_embedding"):
775
+ layer_list2 += [self.vace_patch_embedding]
776
+ layer_list2 += [self.vace_blocks[0].before_proj]
777
  for block in self.vace_blocks:
778
+ layer_list2 += [block.after_proj, block.norm3]
779
+
780
+ target_dtype2 = hybrid_dtype if hybrid_dtype != None else dtype
781
 
782
  # cam master
783
  if hasattr(self.blocks[0], "projector"):
784
  for block in self.blocks:
785
+ layer_list2 += [block.projector]
786
 
787
+ for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dtype, target_dtype2]):
788
+ for layer in current_layer_list:
789
+ layer._lock_dtype = dtype
 
790
 
791
+ if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
792
+ layer.weight.data = layer.weight.data.to(current_dtype)
793
+ if hasattr(layer, "bias"):
794
+ layer.bias.data = layer.bias.data.to(current_dtype)
795
 
796
+ self._lock_dtype = dtype
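The `lock_layers_dtypes` method above implements the new high-quality processing option: a short list of numerically sensitive layers (output head, patch embedding, time embedding and projection, norms, Vace projections) is pinned to float32, or an optional hybrid dtype, while the rest of the transformer stays in 16-bit. A hedged sketch of the idea with made-up module names:

```python
import torch
import torch.nn as nn

def lock_layer_dtypes(model, sensitive_layers, low=torch.bfloat16, high=torch.float32):
    model.to(low)                 # bulk of the weights in low precision
    for layer in sensitive_layers:
        layer.to(high)            # promote the sensitive layers back to float32
        layer._lock_dtype = high  # marker so later dtype conversions skip them

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8))
lock_layer_dtypes(net, sensitive_layers=[net[1], net[2]])
```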
 
797
 
798
 
799
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
 
803
  t = torch.stack([t])
804
  time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
805
  e_list.append(time_emb)
806
+ best_deltas = None
807
  best_threshold = 0.01
808
  best_diff = 1000
809
  best_signed_diff = 1000
 
813
  accumulated_rel_l1_distance =0
814
  nb_steps = 0
815
  diff = 1000
816
+ deltas = []
817
  for i, t in enumerate(timesteps):
818
  skip = False
819
+ if not (i<=start_step or i== len(timesteps)-1):
820
+ delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
821
+ # deltas.append(delta)
822
+ accumulated_rel_l1_distance += delta
823
  if accumulated_rel_l1_distance < threshold:
824
  skip = True
825
+ # deltas.append("SKIP")
826
  else:
827
  accumulated_rel_l1_distance = 0
828
  if not skip:
 
831
  diff = abs(signed_diff)
832
  if diff < best_diff:
833
  best_threshold = threshold
834
+ best_deltas = deltas
835
  best_diff = diff
836
  best_signed_diff = signed_diff
837
  elif diff > best_diff:
 
839
  threshold += 0.01
840
  self.rel_l1_thresh = best_threshold
841
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
842
+ # print(f"deltas:{best_deltas}")
843
  return best_threshold
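`compute_teacache_threshold` above replays the pre-computed time-embedding deltas for a grid of thresholds and keeps the one whose number of skipped steps best matches the requested speed factor. A minimal sketch of the skip rule being simulated, assuming `e_list` holds the per-step modulation embeddings and `coefficients` the fitted rescaling polynomial (the function name is illustrative):

```python
import numpy as np

def count_computed_steps(e_list, coefficients, threshold, start_step):
    """Return how many denoising steps would actually run the transformer blocks."""
    rescale = np.poly1d(coefficients)
    accumulated, computed = 0.0, 0
    for i in range(len(e_list)):
        skip = False
        if not (i <= start_step or i == len(e_list) - 1):
            # relative L1 change between consecutive modulation embeddings
            rel = ((e_list[i] - e_list[i - 1]).abs().mean() / e_list[i - 1].abs().mean()).item()
            accumulated += abs(rescale(rel))
            if accumulated < threshold:
                skip = True           # cheap step: the cached residual is reused
            else:
                accumulated = 0.0     # full step: reset the accumulated distance
        if not skip:
            computed += 1
    return computed
```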
844
 
845
 
 
855
  freqs = None,
856
  pipeline = None,
857
  current_step = 0,
858
+ x_id= 0,
859
  max_steps = 0,
860
  slg_layers=None,
861
  callback = None,
 
863
  fps = None,
864
  causal_block_size = 1,
865
  causal_attention = False,
866
+ audio_proj=None,
867
+ audio_context_lens=None,
868
+ audio_scale=None,
869
+
870
  ):
871
+ # patch_dtype = self.patch_embedding.weight.dtype
872
+ modulation_dtype = self.time_projection[1].weight.dtype
873
 
874
  if self.model_type == 'i2v':
875
  assert clip_fea is not None and y is not None
 
878
  if torch.is_tensor(freqs) and freqs.device != device:
879
  freqs = freqs.to(device)
880
 
 
 
881
 
882
+ x_list = x
883
+ joint_pass = len(x_list) > 1
884
+ is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ]
885
+ last_x_idx = 0
886
+ for i, (is_source, x) in enumerate(zip(is_source_x, x_list)):
887
+ if is_source:
888
+ x_list[i] = x_list[0].clone()
889
+ last_x_idx = i
890
+ else:
891
+ # image source
892
+ if y is not None:
893
+ x = torch.cat([x, y], dim=0)
894
+ # embeddings
895
+ x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
896
+ grid_sizes = x.shape[2:]
897
+ x = x.flatten(2).transpose(1, 2)
898
+ x_list[i] = x
899
+ x, y = None, None
900
+
901
+
902
+ block_mask = None
903
+ if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
904
+ frame_num = grid_sizes[0]
905
+ height = grid_sizes[1]
906
+ width = grid_sizes[2]
907
  block_num = frame_num // causal_block_size
908
  range_tensor = torch.arange(block_num).view(-1, 1)
909
  range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
 
914
  block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
915
  del causal_mask
916
 
917
+ offload.shared_state["embed_sizes"] = grid_sizes
918
  offload.shared_state["step_no"] = current_step
919
  offload.shared_state["max_steps"] = max_steps
920
 
921
+ _flag_df = t.dim() == 2
 
 
 
 
922
 
 
 
 
 
 
923
  e = self.time_embedding(
924
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype)
925
  ) # b, dim
926
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
927
 
928
  if self.inject_sample_info:
929
  fps = torch.tensor(fps, dtype=torch.long, device=device)
930
 
931
+ fps_emb = self.fps_embedding(fps).to(e.dtype)
932
  if _flag_df:
933
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
934
  else:
 
940
  if clip_fea is not None:
941
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
942
  context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
943
+
 
944
  context_list = context
945
+ if audio_scale != None:
946
+ audio_scale_list = audio_scale
947
+ else:
948
+ audio_scale_list = [None] * len(x_list)
949
 
950
  # arguments
951
 
952
  kwargs = dict(
953
  grid_sizes=grid_sizes,
954
  freqs=freqs,
955
+ cam_emb = cam_emb,
956
+ block_mask = block_mask,
957
+ audio_proj=audio_proj,
958
+ audio_context_lens=audio_context_lens,
959
  )
960
 
961
  if vace_context == None:
962
  hints_list = [None ] *len(x_list)
963
  else:
964
+ # Vace embeddings
965
  c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
966
  c = [u.flatten(2).transpose(1, 2) for u in c]
967
  c = c[0]
 
972
 
973
  should_calc = True
974
  if self.enable_teacache:
975
+ if x_id != 0:
976
  should_calc = self.should_calc
977
  else:
978
  if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
 
980
  self.accumulated_rel_l1_distance = 0
981
  else:
982
  rescale_func = np.poly1d(self.coefficients)
983
+ delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
984
+ self.accumulated_rel_l1_distance += delta
985
  if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
986
  should_calc = False
987
  self.teacache_skipped_steps += 1
988
+ # print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" )
989
  else:
990
  should_calc = True
991
  self.accumulated_rel_l1_distance = 0
 
993
  self.should_calc = should_calc
994
 
995
  if not should_calc:
996
+ if joint_pass:
997
+ for i, x in enumerate(x_list):
998
+ x += self.previous_residual[i]
999
+ else:
1000
+ x = x_list[0]
1001
+ x += self.previous_residual[x_id]
1002
+ x = None
1003
  else:
1004
  if self.enable_teacache:
1005
+ if joint_pass:
1006
+ self.previous_residual = [ None ] * len(self.previous_residual)
1007
+ else:
1008
+ self.previous_residual[x_id] = None
1009
+ ori_hidden_states = [ None ] * len(x_list)
1010
+ ori_hidden_states[0] = x_list[0].clone()
1011
+ for i in range(1, len(x_list)):
1012
+ ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone()
1013
 
1014
  for block_idx, block in enumerate(self.blocks):
1015
  offload.shared_state["layer"] = block_idx
 
1018
  if pipeline._interrupt:
1019
  return [None] * len(x_list)
1020
 
1021
+ if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
1022
+ if not joint_pass:
1023
  continue
1024
  x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
 
1025
  else:
1026
+ for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)):
1027
+ x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
1028
  del x
1029
  del context, hints
1030
 
1031
  if self.enable_teacache:
1032
  if joint_pass:
1033
+ for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
1034
+ if i == 0 or is_source and i != last_x_idx :
1035
+ self.previous_residual[i] = torch.sub(x, ori)
1036
+ else:
1037
+ self.previous_residual[i] = ori
1038
+ torch.sub(x, ori, out=self.previous_residual[i])
1039
+ ori_hidden_states[i] = None
1040
+ x , ori = None, None
1041
  else:
1042
+ residual = ori_hidden_states[0] # alias the buffer to keep the code readable
1043
+ torch.sub(x_list[0], ori_hidden_states[0], out=residual)
1044
+ self.previous_residual[x_id] = residual
 
 
 
1045
  residual, ori_hidden_states = None, None
1046
 
1047
  for i, x in enumerate(x_list):
 
1072
 
1073
  c = self.out_dim
1074
  out = []
1075
+ for u in x:
1076
+ u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
1077
  u = torch.einsum('fhwpqrc->cfphqwr', u)
1078
+ u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
1079
  out.append(u)
1080
  return out
1081
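Summary of the forward() refactor in this file: the old `is_uncond` / `x_neg` arguments are replaced by a list of latent streams plus an `x_id`, streams that alias the first tensor are cloned, and TeaCache keeps one cached residual per stream. A toy sketch of that calling pattern (the class below is illustrative, not `WanModel`):

```python
import torch

class TinyDenoiser(torch.nn.Module):
    def __init__(self, dim=16, n_streams=2):
        super().__init__()
        self.block = torch.nn.Linear(dim, dim)
        self.previous_residual = [None] * n_streams

    def forward(self, x_list, x_id=0, reuse_cached=False):
        joint_pass = len(x_list) > 1
        # streams passed in as aliases of the first tensor get their own copy
        x_list = [x if i == 0 or x.data_ptr() != x_list[0].data_ptr()
                  else x_list[0].clone() for i, x in enumerate(x_list)]
        ids = range(len(x_list)) if joint_pass else [x_id]
        if reuse_cached:                         # skipped step: add back the cached residuals
            for i, x in zip(ids, x_list):
                x += self.previous_residual[i]
            return x_list
        originals = [x.clone() for x in x_list]
        x_list = [self.block(x) for x in x_list]
        for i, x, ori in zip(ids, x_list, originals):
            self.previous_residual[i] = x - ori  # cache the residual for future skips
        return x_list

model = TinyDenoiser()
cond = torch.randn(1, 8, 16)
out_cond, out_uncond = model([cond, cond])                            # joint cond/uncond pass
reused = model([torch.randn(1, 8, 16)], x_id=1, reuse_cached=True)    # single-stream reuse
```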
 
wan/modules/sage2_core.py CHANGED
@@ -140,7 +140,7 @@ def sageattn(
140
  elif arch == "sm90":
141
  return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
142
  elif arch == "sm120":
143
- return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
144
  else:
145
  raise ValueError(f"Unsupported CUDA architecture: {arch}")
146
 
 
140
  elif arch == "sm90":
141
  return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
142
  elif arch == "sm120":
143
+ return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
144
  else:
145
  raise ValueError(f"Unsupported CUDA architecture: {arch}")
146
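For context, the change above only adds `smooth_v=True` on the sm120 branch. A hedged sketch of this per-architecture dispatch pattern (the returned options are illustrative and not SageAttention's real signature):

```python
import torch

def pick_fp8_attention_options(device_index=0):
    major, minor = torch.cuda.get_device_capability(device_index)
    arch = f"sm{major}{minor}"
    if arch == "sm90":
        return {"pv_accum_dtype": "fp32+fp32"}
    if arch == "sm120":
        # sm120 has an accurate fp32 accumulator for fp8 MMA, so fp32 plus value smoothing
        return {"pv_accum_dtype": "fp32", "smooth_v": True}
    raise ValueError(f"Unsupported CUDA architecture: {arch}")
```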
 
wan/text2video.py CHANGED
@@ -78,15 +78,16 @@ class WanT2V:
78
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
79
  device=self.device)
80
 
81
- logging.info(f"Creating WanModel from {model_filename}")
82
  from mmgp import offload
83
  # model_filename
 
84
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
85
  # offload.load_model_data(self.model, "e:/vace.safetensors")
86
  # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
87
  # self.model.to(torch.bfloat16)
88
  # self.model.cpu()
89
- self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
90
  offload.change_dtype(self.model, dtype, True)
91
  # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
92
  # offload.save_model(self.model, "phantom_1.3B.safetensors")
@@ -95,7 +96,7 @@ class WanT2V:
95
 
96
  self.sample_neg_prompt = config.sample_neg_prompt
97
 
98
- if "Vace" in model_filename:
99
  self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
100
  min_area=480*832,
101
  max_area=480*832,
@@ -107,7 +108,7 @@ class WanT2V:
107
 
108
  self.adapt_vace_model()
109
 
110
- def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
111
  if ref_images is None:
112
  ref_images = [None] * len(frames)
113
  else:
@@ -119,6 +120,11 @@ class WanT2V:
119
  inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
120
  reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
121
  inactive = self.vae.encode(inactive, tile_size = tile_size)
 
 
 
 
 
122
  reactive = self.vae.encode(reactive, tile_size = tile_size)
123
  latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
124
 
@@ -288,7 +294,10 @@ class WanT2V:
288
  slg_end = 1.0,
289
  cfg_star_switch = True,
290
  cfg_zero_step = 5,
291
- ):
 
 
 
292
  r"""
293
  Generates video frames from text prompt using diffusion process.
294
 
@@ -343,20 +352,20 @@ class WanT2V:
343
  size = (source_video.shape[2], source_video.shape[1])
344
  source_video = source_video.to(dtype=self.dtype , device=self.device)
345
  source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
346
- source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
347
  del source_video
348
  # Process target camera (recammaster)
349
  from wan.utils.cammmaster_tools import get_camera_embedding
350
  cam_emb = get_camera_embedding(target_camera)
351
  cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
352
 
353
- if input_frames != None:
354
  # vace context encode
355
  input_frames = [u.to(self.device) for u in input_frames]
356
  input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
357
  input_masks = [u.to(self.device) for u in input_masks]
358
 
359
- z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
360
  m0 = self.vace_encode_masks(input_masks, input_ref_images)
361
  z = self.vace_latent(z0, m0)
362
 
@@ -365,10 +374,10 @@ class WanT2V:
365
  else:
366
  if input_ref_images != None: # Phantom Ref images
367
  phantom = True
368
- input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
369
- input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
370
  F = frame_num
371
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
372
  size[1] // self.vae_stride[1],
373
  size[0] // self.vae_stride[2])
374
 
@@ -405,37 +414,48 @@ class WanT2V:
405
  raise NotImplementedError("Unsupported solver.")
406
 
407
  # sample videos
408
- latents = noise
409
  del noise
410
- batch_size =len(latents)
411
  if target_camera != None:
412
- shape = list(latents[0].shape[1:])
413
  shape[0] *= 2
414
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
415
  else:
416
- freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
417
 
418
  kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
419
 
420
  if target_camera != None:
421
  kwargs.update({'cam_emb': cam_emb})
422
 
423
- if input_frames != None:
 
424
  kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
 
 
425
 
426
 
427
  if self.model.enable_teacache:
 
 
428
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
429
  if callback != None:
430
  callback(-1, None, True)
431
  for i, t in enumerate(tqdm(timesteps)):
 
 
 
 
 
 
 
432
  if target_camera != None:
433
- latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
434
  else:
435
  latent_model_input = latents
436
- slg_layers_local = None
437
- if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
438
- slg_layers_local = slg_layers
439
  timestep = [t]
440
  offload.set_step_no_for_lora(self.model, i)
441
  timestep = torch.stack(timestep)
@@ -444,38 +464,38 @@ class WanT2V:
444
  if joint_pass:
445
  if phantom:
446
  pos_it, pos_i, neg = self.model(
447
- [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
448
- x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
449
  context = [context, context_null, context_null], **kwargs)
450
  else:
451
  noise_pred_cond, noise_pred_uncond = self.model(
452
- latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
453
  if self._interrupt:
454
  return None
455
  else:
456
  if phantom:
457
  pos_it = self.model(
458
- [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
459
  )[0]
460
  if self._interrupt:
461
  return None
462
  pos_i = self.model(
463
- [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs
464
  )[0]
465
  if self._interrupt:
466
  return None
467
  neg = self.model(
468
- [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
469
  )[0]
470
  if self._interrupt:
471
  return None
472
  else:
473
  noise_pred_cond = self.model(
474
- latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
475
  if self._interrupt:
476
  return None
477
  noise_pred_uncond = self.model(
478
- latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0]
479
  if self._interrupt:
480
  return None
481
 
@@ -505,21 +525,21 @@ class WanT2V:
505
  temp_x0 = sample_scheduler.step(
506
  noise_pred[:, :target_shape[1]].unsqueeze(0),
507
  t,
508
- latents[0].unsqueeze(0),
509
  return_dict=False,
510
  generator=seed_g)[0]
511
- latents = [temp_x0.squeeze(0)]
512
  del temp_x0
513
 
514
  if callback is not None:
515
- callback(i, latents[0], False)
516
 
517
- x0 = latents
518
 
519
  if input_frames == None:
520
  if phantom:
521
  # phantom post processing
522
- x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
523
  videos = self.vae.decode(x0, VAE_tile_size)
524
  else:
525
  # vace post processing
 
78
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
79
  device=self.device)
80
 
81
+ logging.info(f"Creating WanModel from {model_filename[-1]}")
82
  from mmgp import offload
83
  # model_filename
84
+
85
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
86
  # offload.load_model_data(self.model, "e:/vace.safetensors")
87
  # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
88
  # self.model.to(torch.bfloat16)
89
  # self.model.cpu()
90
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
91
  offload.change_dtype(self.model, dtype, True)
92
  # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
93
  # offload.save_model(self.model, "phantom_1.3B.safetensors")
 
96
 
97
  self.sample_neg_prompt = config.sample_neg_prompt
98
 
99
+ if "Vace" in model_filename[-1]:
100
  self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
101
  min_area=480*832,
102
  max_area=480*832,
 
108
 
109
  self.adapt_vace_model()
110
 
111
+ def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0):
112
  if ref_images is None:
113
  ref_images = [None] * len(frames)
114
  else:
 
120
  inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
121
  reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
122
  inactive = self.vae.encode(inactive, tile_size = tile_size)
123
+ # inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive]
124
+ # if overlapped_latents > 0:
125
+ # for t in inactive:
126
+ # t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor
127
+
128
  reactive = self.vae.encode(reactive, tile_size = tile_size)
129
  latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
130
 
 
294
  slg_end = 1.0,
295
  cfg_star_switch = True,
296
  cfg_zero_step = 5,
297
+ overlapped_latents = 0,
298
+ overlap_noise = 0,
299
+ vace = False
300
+ ):
301
  r"""
302
  Generates video frames from text prompt using diffusion process.
303
 
 
352
  size = (source_video.shape[2], source_video.shape[1])
353
  source_video = source_video.to(dtype=self.dtype , device=self.device)
354
  source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
355
+ source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device)
356
  del source_video
357
  # Process target camera (recammaster)
358
  from wan.utils.cammmaster_tools import get_camera_embedding
359
  cam_emb = get_camera_embedding(target_camera)
360
  cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
361
 
362
+ if vace :
363
  # vace context encode
364
  input_frames = [u.to(self.device) for u in input_frames]
365
  input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
366
  input_masks = [u.to(self.device) for u in input_masks]
367
 
368
+ z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise )
369
  m0 = self.vace_encode_masks(input_masks, input_ref_images)
370
  z = self.vace_latent(z0, m0)
371
 
 
374
  else:
375
  if input_ref_images != None: # Phantom Ref images
376
  phantom = True
377
+ input_ref_images = self.get_vae_latents(input_ref_images, self.device)
378
+ input_ref_images_neg = torch.zeros_like(input_ref_images)
379
  F = frame_num
380
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
381
  size[1] // self.vae_stride[1],
382
  size[0] // self.vae_stride[2])
383
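The Phantom branch above now keeps the reference-image latents as a single tensor: during sampling the trailing latent frames are replaced by the clean reference latents before each model call, and they are sliced off again before decoding. A small sketch under those assumptions (shapes and helper names are illustrative):

```python
import torch

def inject_refs(latent, ref_latents):
    """latent: [C, T+R, H, W] noisy latents; the last R frames carry the reference images."""
    return torch.cat([latent[:, :-ref_latents.shape[1]], ref_latents], dim=1)

def strip_refs(latent, ref_latents):
    # the reference frames are dropped again before VAE decoding
    return latent[:, :-ref_latents.shape[1]]

latent = torch.randn(16, 23, 30, 52)   # 21 video latents + 2 reference slots
refs = torch.randn(16, 2, 30, 52)
assert strip_refs(inject_refs(latent, refs), refs).shape == (16, 21, 30, 52)
```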
 
 
414
  raise NotImplementedError("Unsupported solver.")
415
 
416
  # sample videos
417
+ latents = noise[0]
418
  del noise
419
+ batch_size = 1
420
  if target_camera != None:
421
+ shape = list(latents.shape[1:])
422
  shape[0] *= 2
423
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
424
  else:
425
+ freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
426
 
427
  kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
428
 
429
  if target_camera != None:
430
  kwargs.update({'cam_emb': cam_emb})
431
 
432
+ if vace:
433
+ ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
434
  kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
435
+ if overlapped_latents > 0:
436
+ z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
437
 
438
 
439
  if self.model.enable_teacache:
440
+ x_count = 3 if phantom else 2
441
+ self.model.previous_residual = [None] * x_count
442
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
443
  if callback != None:
444
  callback(-1, None, True)
445
  for i, t in enumerate(tqdm(timesteps)):
446
+ if vace and overlapped_latents > 0 :
447
+ # noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
448
+ noise_factor = overlap_noise / 1000 # * (999-t) / 999
449
+ # noise_factor = overlap_noise / 1000 # * t / 999
450
+ for zz, zz_r in zip(z, z_reactive):
451
+ zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor
452
+
453
  if target_camera != None:
454
+ latent_model_input = torch.cat([latents, source_latents], dim=1)
455
  else:
456
  latent_model_input = latents
457
+ kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
458
+
 
459
  timestep = [t]
460
  offload.set_step_no_for_lora(self.model, i)
461
  timestep = torch.stack(timestep)
 
464
  if joint_pass:
465
  if phantom:
466
  pos_it, pos_i, neg = self.model(
467
+ [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
468
+ [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
469
  context = [context, context_null, context_null], **kwargs)
470
  else:
471
  noise_pred_cond, noise_pred_uncond = self.model(
472
+ [latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
473
  if self._interrupt:
474
  return None
475
  else:
476
  if phantom:
477
  pos_it = self.model(
478
+ [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
479
  )[0]
480
  if self._interrupt:
481
  return None
482
  pos_i = self.model(
483
+ [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
484
  )[0]
485
  if self._interrupt:
486
  return None
487
  neg = self.model(
488
+ [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
489
  )[0]
490
  if self._interrupt:
491
  return None
492
  else:
493
  noise_pred_cond = self.model(
494
+ [latent_model_input], x_id = 0, context = [context], **kwargs)[0]
495
  if self._interrupt:
496
  return None
497
  noise_pred_uncond = self.model(
498
+ [latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
499
  if self._interrupt:
500
  return None
501
 
 
525
  temp_x0 = sample_scheduler.step(
526
  noise_pred[:, :target_shape[1]].unsqueeze(0),
527
  t,
528
+ latents.unsqueeze(0),
529
  return_dict=False,
530
  generator=seed_g)[0]
531
+ latents = temp_x0.squeeze(0)
532
  del temp_x0
533
 
534
  if callback is not None:
535
+ callback(i, latents, False)
536
 
537
+ x0 = [latents]
538
 
539
  if input_frames == None:
540
  if phantom:
541
  # phantom post processing
542
+ x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
543
  videos = self.vae.decode(x0, VAE_tile_size)
544
  else:
545
  # vace post processing
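The new `overlapped_latents` / `overlap_noise` handling above is what produces smoother transitions between Vace sliding windows: before every denoising step, the latent frames shared with the previous window are blended with a small amount of fresh noise so the seam can be re-harmonised instead of copied verbatim. A hedged sketch of that blend (tensor shapes and the helper name are illustrative):

```python
import torch

def renoise_overlap(z, z_reference, overlap, ref_count, overlap_noise=20):
    """z: [C, T, H, W] conditioning latents; z_reference: clean latents of the overlapped frames."""
    noise_factor = overlap_noise / 1000.0
    sl = slice(ref_count, ref_count + overlap)
    z[:16, sl] = z_reference * (1.0 - noise_factor) + torch.randn_like(z_reference) * noise_factor
    return z

z = torch.randn(32, 21, 30, 52)
z_ref = z[:16, 2:2 + 4].clone()
z = renoise_overlap(z, z_ref, overlap=4, ref_count=2)
```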
wgp.py CHANGED
@@ -40,7 +40,7 @@ global_queue_ref = []
40
  AUTOSAVE_FILENAME = "queue.zip"
41
  PROMPT_VARS_MAX = 10
42
 
43
- target_mmgp_version = "3.4.1"
44
  from importlib.metadata import version
45
  mmgp_version = version("mmgp")
46
  if mmgp_version != target_mmgp_version:
@@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version:
49
  lock = threading.Lock()
50
  current_task_id = None
51
  task_id = 0
52
- # progress_tracker = {}
53
- # tracker_lock = threading.Lock()
54
-
55
- # def download_ffmpeg():
56
- # if os.name != 'nt': return
57
- # exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
58
- # if all(os.path.exists(e) for e in exes): return
59
- # api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
60
- # r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
61
- # assets = r.json().get('assets', [])
62
- # zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
63
- # if not zip_asset: return
64
- # zip_url = zip_asset['browser_download_url']
65
- # zip_name = zip_asset['name']
66
- # with requests.get(zip_url, stream=True) as resp:
67
- # total = int(resp.headers.get('Content-Length', 0))
68
- # with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
69
- # for chunk in resp.iter_content(chunk_size=8192):
70
- # f.write(chunk)
71
- # pbar.update(len(chunk))
72
- # with zipfile.ZipFile(zip_name) as z:
73
- # for f in z.namelist():
74
- # if f.endswith(tuple(exes)) and '/bin/' in f:
75
- # z.extract(f)
76
- # os.rename(f, os.path.basename(f))
77
- # os.remove(zip_name)
78
 
79
  def format_time(seconds):
80
  if seconds < 60:
@@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice):
168
  resolution = inputs["resolution"]
169
  width, height = resolution.split("x")
170
  width, height = int(width), int(height)
171
- if test_class_i2v(model_filename):
172
- if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
173
- gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
174
- return
175
- resolution = str(width) + "*" + str(height)
176
- if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
177
- gr.Info(f"Resolution {resolution} not supported by image 2 video")
178
- return
179
 
180
  if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
181
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
@@ -533,7 +531,7 @@ def save_queue_action(state):
533
  task_id_s = task.get('id', f"task_{task_index}")
534
 
535
  image_keys = ["image_start", "image_end", "image_refs"]
536
- video_keys = ["video_guide", "video_mask", "video_source"]
537
 
538
  for key in image_keys:
539
  images_pil = params_copy.get(key)
@@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
707
  params['state'] = state
708
 
709
  image_keys = ["image_start", "image_end", "image_refs"]
710
- video_keys = ["video_guide", "video_mask", "video_source"]
711
 
712
  loaded_pil_images = {}
713
  loaded_video_paths = {}
@@ -925,7 +923,7 @@ def autosave_queue():
925
  task_id_s = task.get('id', f"task_{task_index}")
926
 
927
  image_keys = ["image_start", "image_end", "image_refs"]
928
- video_keys = ["video_guide", "video_mask", "video_source"]
929
 
930
  for key in image_keys:
931
  images_pil = params_copy.get(key)
@@ -1418,32 +1416,35 @@ else:
1418
  text = reader.read()
1419
  server_config = json.loads(text)
1420
 
1421
- # for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
1422
- # if Path(src_path).is_file():
1423
- # shutil.move(src_path, tgt_path) )
1424
- # for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
1425
- # if Path(path).is_file():
1426
- # os.remove(path)
 
1427
 
1428
- path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
1429
- if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
1430
- os.remove(path)
1431
 
1432
  transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
1433
  "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
1434
- "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
1435
  "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
1436
- transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
1437
- "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
1438
- "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
 
1439
  transformer_choices = transformer_choices_t2v + transformer_choices_i2v
1440
-
1441
- model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
 
 
 
 
1442
  model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
1443
  "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
1444
  "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
1445
  "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
1446
- "phantom_1.3B" : "phantom_1.3B", }
1447
 
1448
 
1449
  def get_model_type(model_filename):
@@ -1453,7 +1454,7 @@ def get_model_type(model_filename):
1453
  raise Exception("Unknown model:" + model_filename)
1454
 
1455
  def test_class_i2v(model_filename):
1456
- return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
1457
 
1458
  def get_model_name(model_filename, description_container = [""]):
1459
  if "Fun" in model_filename:
@@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]):
1491
  model_name = "Wan2.1 Phantom"
1492
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1493
  description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
 
 
 
 
1494
  else:
1495
  model_name = "Wan2.1 text2video"
1496
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
@@ -1536,6 +1541,7 @@ def get_default_settings(filename):
1536
  "repeat_generation": 1,
1537
  "multi_images_gen_type": 0,
1538
  "guidance_scale": 5.0,
 
1539
  "flow_shift": get_default_flow(filename, i2v),
1540
  "negative_prompt": "",
1541
  "activated_loras": [],
@@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename):
1719
 
1720
  from huggingface_hub import hf_hub_download, snapshot_download
1721
  repoId = "DeepBeepMeep/Wan2.1"
1722
- sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
1723
- fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
 
1724
  targetRoot = "ckpts/"
1725
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1726
  if len(files)==0:
@@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
1834
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
1835
 
1836
 
1837
- def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1838
 
1839
  cfg = WAN_CONFIGS['t2v-14B']
 
1840
  # cfg = WAN_CONFIGS['t2v-1.3B']
1841
- print(f"Loading '{model_filename}' model...")
1842
- if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
1843
  model_factory = wan.DTT2V
1844
  else:
1845
  model_factory = wan.WanT2V
@@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
1859
 
1860
  return wan_model, pipe
1861
 
1862
- def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1863
 
1864
- print(f"Loading '{model_filename}' model...")
 
1865
 
1866
  cfg = WAN_CONFIGS['i2v-14B']
1867
  wan_model = wan.WanI2V(
@@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t
1883
  def load_models(model_filename):
1884
  global transformer_filename
1885
 
1886
- transformer_filename = model_filename
1887
  perc_reserved_mem_max = args.perc_reserved_mem_max
1888
 
1889
  major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
@@ -1892,19 +1900,27 @@ def load_models(model_filename):
1892
  default_dtype = torch.float16
1893
  else:
1894
  default_dtype = torch.float16 if args.fp16 else torch.bfloat16
1895
- if default_dtype == torch.float16 :
1896
- if "quanto" in model_filename:
1897
- model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
1898
- download_models(model_filename, text_encoder_filename)
 
 
 
 
 
 
 
1899
  VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
1900
  mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
1901
- if test_class_i2v(model_filename):
1902
- res720P = "720p" in model_filename
1903
- wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
 
1904
  else:
1905
- wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1906
- wan_model._model_file_name = model_filename
1907
- kwargs = { "extraModelsToQuantize": None}
1908
  if profile == 2 or profile == 4:
1909
  kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
1910
  # if profile == 4:
@@ -1914,7 +1930,7 @@ def load_models(model_filename):
1914
  offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
1915
  if len(args.gpu) > 0:
1916
  torch.set_default_device(args.gpu)
1917
-
1918
  return wan_model, offloadobj, pipe["transformer"]
1919
 
1920
  if not "P" in preload_model_policy:
@@ -2033,13 +2049,7 @@ def apply_changes( state,
2033
  preload_model_policy = server_config["preload_model_policy"]
2034
  transformer_quantization = server_config["transformer_quantization"]
2035
  transformer_types = server_config["transformer_types"]
2036
- model_filename = state["model_filename"]
2037
- model_transformer_type = get_model_type(model_filename)
2038
 
2039
- if not model_transformer_type in transformer_types:
2040
- model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
2041
- model_filename = get_model_filename(model_transformer_type, transformer_quantization)
2042
- state["model_filename"] = model_filename
2043
  if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
2044
  model_choice = gr.Dropdown()
2045
  else:
@@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
2249
  frame_height, frame_width, _ = frames_list[0].shape
2250
 
2251
  if fit_canvas :
2252
- scale = min(height / frame_height, width / frame_width)
 
 
2253
  else:
2254
  scale = ((height * width ) / (frame_height * frame_width))**(1/2)
2255
 
@@ -2356,6 +2368,7 @@ def generate_video(
2356
  seed,
2357
  num_inference_steps,
2358
  guidance_scale,
 
2359
  flow_shift,
2360
  embedded_guidance_scale,
2361
  repeat_generation,
@@ -2375,8 +2388,10 @@ def generate_video(
2375
  video_guide,
2376
  keep_frames_video_guide,
2377
  video_mask,
 
2378
  sliding_window_size,
2379
  sliding_window_overlap,
 
2380
  sliding_window_discard_last_frames,
2381
  remove_background_image_ref,
2382
  temporal_upsampling,
@@ -2508,6 +2523,15 @@ def generate_video(
2508
  # VAE Tiling
2509
  device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
2510
 
2511
  joint_pass = boost ==1 #and profile != 1 and profile != 3
2512
  # TeaCache
2513
  trans.enable_teacache = tea_cache_setting > 0
@@ -2517,12 +2541,10 @@ def generate_video(
2517
  trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
2518
 
2519
  if image2video:
2520
- if '480p' in model_filename:
2521
- trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
2522
- elif '720p' in model_filename:
2523
  trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
2524
  else:
2525
- raise gr.Error("Teacache not supported for this model")
2526
  else:
2527
  if '1.3B' in model_filename:
2528
  trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
@@ -2535,6 +2557,18 @@ def generate_video(
2535
  if "recam" in model_filename:
2536
  source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
2537
  target_camera = model_mode
 
2538
  import random
2539
  if seed == None or seed <0:
2540
  seed = random.randint(0, 999999999)
@@ -2551,11 +2585,11 @@ def generate_video(
2551
  extra_generation = 0
2552
  initial_total_windows = 0
2553
  max_frames_to_generate = video_length
2554
- diffusion_forcing = "diffusion_forcing" in model_filename
2555
- vace = "Vace" in model_filename
2556
  phantom = "phantom" in model_filename
2557
  if diffusion_forcing or vace:
2558
  reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
 
 
2559
  if diffusion_forcing and source_video != None:
2560
  video_length += sliding_window_overlap
2561
  sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
@@ -2571,10 +2605,8 @@ def generate_video(
2571
  initial_total_windows = 1
2572
 
2573
  first_window_video_length = video_length
2574
- fps = 24 if diffusion_forcing else 16
2575
 
2576
  gen["sliding_window"] = sliding_window
2577
-
2578
  while not abort:
2579
  extra_generation += gen.get("extra_orders",0)
2580
  gen["extra_orders"] = 0
@@ -2594,6 +2626,7 @@ def generate_video(
2594
  guide_start_frame = 0
2595
  video_length = first_window_video_length
2596
  gen["extra_windows"] = 0
 
2597
  while not abort:
2598
  if sliding_window:
2599
  prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
@@ -2642,7 +2675,7 @@ def generate_video(
2642
 
2643
  if preprocess_type != None :
2644
  send_cmd("progress", progress_args)
2645
- video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps)
2646
  keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
2647
  if len(error) > 0:
2648
  raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
@@ -2678,8 +2711,7 @@ def generate_video(
2678
  trans.teacache_counter = 0
2679
  trans.num_steps = num_inference_steps
2680
  trans.teacache_skipped_steps = 0
2681
- trans.previous_residual_uncond = None
2682
- trans.previous_residual_cond = None
2683
 
2684
  if image2video:
2685
  samples = wan_model.generate(
@@ -2687,7 +2719,9 @@ def generate_video(
2687
  image_start,
2688
  image_end if image_end != None else None,
2689
  frame_num=(video_length // 4)* 4 + 1,
2690
- max_area=MAX_AREA_CONFIGS[resolution_reformated],
 
 
2691
  shift=flow_shift,
2692
  sampling_steps=num_inference_steps,
2693
  guide_scale=guidance_scale,
@@ -2702,7 +2736,11 @@ def generate_video(
2702
  slg_end = slg_end_perc/100,
2703
  cfg_star_switch = cfg_star_switch,
2704
  cfg_zero_step = cfg_zero_step,
2705
- add_frames_for_end_image = "image2video" in model_filename
 
 
 
 
2706
  )
2707
  elif diffusion_forcing:
2708
  samples = wan_model.generate(
@@ -2720,14 +2758,17 @@ def generate_video(
2720
  callback= callback,
2721
  VAE_tile_size = VAE_tile_size,
2722
  joint_pass = joint_pass,
2723
- addnoise_condition = 20,
 
 
 
2724
  ar_step = model_mode, #5
2725
  causal_block_size = 5,
2726
  causal_attention = True,
2727
  fps = fps,
2728
  )
2729
  else:
2730
- samples = wan_model.generate(
2731
  prompt,
2732
  input_frames = src_video,
2733
  input_ref_images= src_ref_images,
@@ -2751,6 +2792,9 @@ def generate_video(
2751
  slg_end = slg_end_perc/100,
2752
  cfg_star_switch = cfg_star_switch,
2753
  cfg_zero_step = cfg_zero_step,
 
 
 
2754
  )
2755
  except Exception as e:
2756
  if temp_filename!= None and os.path.isfile(temp_filename):
@@ -2782,11 +2826,11 @@ def generate_video(
2782
  print('\n'.join(tb))
2783
  send_cmd("error", new_error)
2784
  return
 
 
2785
 
2786
  if trans.enable_teacache:
2787
- print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
2788
- trans.previous_residual_uncond = None
2789
- trans.previous_residual_cond = None
2790
 
2791
  if samples != None:
2792
  samples = samples.to("cpu")
@@ -2810,14 +2854,27 @@ def generate_video(
2810
  if discard_last_frames > 0:
2811
  sample = sample[: , :-discard_last_frames]
2812
  guide_start_frame -= discard_last_frames
2813
- pre_video_guide = sample[:, -reuse_frames:]
 
 
 
 
 
 
 
2814
  if prefix_video != None:
2815
- sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
 
 
 
2816
  prefix_video = None
2817
  if sliding_window and window_no > 1:
2818
- sample = sample[: , reuse_frames:]
2819
- guide_start_frame -= reuse_frames
 
 
2820
 
 
2821
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
2822
  if os.name == 'nt':
2823
  file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
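Around here the sliding-window bookkeeping trims each generated window before it is appended to the final video: a few trailing frames may be discarded and the frames that overlap the previous window are dropped, with the tail kept as the guide for the next window. A simplified sketch of that assembly (the function and its defaults are illustrative, not the exact logic of `generate_video`):

```python
import torch

def stitch_window(video_so_far, sample, reuse_frames=9, discard_last_frames=4):
    """video_so_far / sample: [C, T, H, W] decoded frames; returns (video, guide_for_next_window)."""
    if discard_last_frames > 0:
        sample = sample[:, :-discard_last_frames]   # drop the possibly blurry tail
    guide = sample[:, -reuse_frames:]               # frames the next window will overlap
    if video_so_far is not None:
        sample = sample[:, reuse_frames:]           # the overlap is already in video_so_far
        sample = torch.cat([video_so_far, sample], dim=1)
    return sample, guide

video, guide = stitch_window(None, torch.randn(3, 81, 64, 64))
video, guide = stitch_window(video, torch.randn(3, 81, 64, 64))
```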
@@ -2875,18 +2932,23 @@ def generate_video(
2875
  sample = torch.cat([frames_already_processed, sample], dim=1)
2876
  frames_already_processed = sample
2877
 
2878
- cache_video(
2879
- tensor=sample[None],
2880
- save_file=video_path,
2881
- fps=fps,
2882
- nrow=1,
2883
- normalize=True,
2884
- value_range=(-1, 1))
 
 
 
 
2885
 
2886
  inputs = get_function_arguments(generate_video, locals())
2887
  inputs.pop("send_cmd")
 
2888
  configs = prepare_inputs_dict("metadata", inputs)
2889
-
2890
  metadata_choice = server_config.get("metadata_type","metadata")
2891
  if metadata_choice == "json":
2892
  with open(video_path.replace('.mp4', '.json'), 'w') as f:
@@ -3113,7 +3175,7 @@ def get_latest_status(state):
3113
  prompt_no = gen["prompt_no"]
3114
  prompts_max = gen.get("prompts_max",0)
3115
  total_generation = gen.get("total_generation", 1)
3116
- repeat_no = gen["repeat_no"]
3117
  total_generation += gen.get("extra_orders", 0)
3118
  total_windows = gen.get("total_windows", 0)
3119
  total_windows += gen.get("extra_windows", 0)
@@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ):
3456
 
3457
  if target == "state":
3458
  return inputs
3459
- unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
3460
  for k in unsaved_params:
3461
  inputs.pop(k)
3462
 
@@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ):
3484
  inputs.pop(k)
3485
 
3486
  if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
3487
- unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"]
3488
  for k in unsaved_params:
3489
  inputs.pop(k)
3490
 
 
 
 
 
3491
  if target == "metadata":
3492
  inputs = {k: v for k,v in inputs.items() if v != None }
3493
 
@@ -3511,6 +3577,7 @@ def save_inputs(
3511
  seed,
3512
  num_inference_steps,
3513
  guidance_scale,
 
3514
  flow_shift,
3515
  embedded_guidance_scale,
3516
  repeat_generation,
@@ -3530,8 +3597,10 @@ def save_inputs(
3530
  video_guide,
3531
  keep_frames_video_guide,
3532
  video_mask,
 
3533
  sliding_window_size,
3534
  sliding_window_overlap,
 
3535
  sliding_window_discard_last_frames,
3536
  remove_background_image_ref,
3537
  temporal_upsampling,
@@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3834
  recammaster = "recam" in model_filename
3835
  vace = "Vace" in model_filename
3836
  phantom = "phantom" in model_filename
 
3837
  with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
3838
  if diffusion_forcing:
3839
  image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
@@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3939
 
3940
 
3941
  video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpainting, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
3942
-
3943
 
3944
  advanced_prompt = advanced_ui
3945
  prompt_vars=[]
@@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3972
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
3973
  wizard_variables_var = gr.Text(wizard_variables, visible = False)
3974
  with gr.Row():
3975
- if test_class_i2v(model_filename):
3976
  resolution = gr.Dropdown(
3977
  choices=[
3978
  # 720p
3979
- ("720p", "1280x720"),
3980
- ("480p", "832x480"),
3981
  ],
3982
  value=ui_defaults.get("resolution","480p"),
3983
  label="Resolution (video will have the same height / width ratio than the original image)"
@@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3989
  ("1280x720 (16:9, 720p)", "1280x720"),
3990
  ("720x1280 (9:16, 720p)", "720x1280"),
3991
  ("1024x1024 (4:3, 720p)", "1024x024"),
3992
- # ("832x1104 (3:4, 720p)", "832x1104"),
3993
- # ("960x960 (1:1, 720p)", "960x960"),
 
3994
  # 480p
3995
  ("960x544 (16:9, 540p)", "960x544"),
3996
  ("544x960 (16:9, 540p)", "544x960"),
3997
  ("832x480 (16:9, 480p)", "832x480"),
3998
  ("480x832 (9:16, 480p)", "480x832"),
3999
- # ("832x624 (4:3, 540p)", "832x624"),
4000
- # ("624x832 (3:4, 540p)", "624x832"),
4001
- # ("720x720 (1:1, 540p)", "720x720"),
 
4002
  ],
4003
  value=ui_defaults.get("resolution","832x480"),
4004
- label="Resolution"
4005
  )
4006
  with gr.Row():
4007
  if recammaster:
@@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4010
  video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
4011
  elif vace:
4012
  video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
 
 
4013
  else:
4014
  video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4015
  with gr.Row():
@@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4029
  choices=[
4030
  ("Generate every combination of images and texts", 0),
4031
  ("Match images and text prompts", 1),
4032
- ], visible= True, label= "Multiple Images as Texts Prompts"
4033
  )
4034
  with gr.Row():
4035
  guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
 
4036
  embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
4037
  flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
4038
  with gr.Row():
@@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4099
 
4100
  with gr.Tab("Quality"):
4101
  with gr.Row():
4102
- gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
4103
  with gr.Row():
4104
  slg_switch = gr.Dropdown(
4105
  choices=[
@@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4148
  if diffusion_forcing:
4149
  sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
4150
  sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4151
- sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
 
4152
  else:
4153
  sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4154
- sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4155
- sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
 
4156
 
4157
 
4158
  with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
@@ -4167,8 +4244,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4167
  label="RIFLEx positional embedding to generate long video"
4168
  )
4169
 
4170
- with gr.Row():
4171
- save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
4172
 
4173
  if not update_form:
4174
  with gr.Column():
@@ -5035,7 +5112,7 @@ def create_demo():
5035
  theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
5036
 
5037
  with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
5038
- gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
5039
  global model_list
5040
 
5041
  tab_state = gr.State({ "tab_no":0 })
@@ -5076,7 +5153,7 @@ def create_demo():
5076
 
5077
  if __name__ == "__main__":
5078
  atexit.register(autosave_queue)
5079
- # download_ffmpeg()
5080
  # threading.Thread(target=runner, daemon=True).start()
5081
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
5082
  server_port = int(args.server_port)
 
40
  AUTOSAVE_FILENAME = "queue.zip"
41
  PROMPT_VARS_MAX = 10
42
 
43
+ target_mmgp_version = "3.4.2"
44
  from importlib.metadata import version
45
  mmgp_version = version("mmgp")
46
  if mmgp_version != target_mmgp_version:
 
49
  lock = threading.Lock()
50
  current_task_id = None
51
  task_id = 0
52
+
53
+ def download_ffmpeg():
54
+ if os.name != 'nt': return
55
+ exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
56
+ if all(os.path.exists(e) for e in exes): return
57
+ api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
58
+ r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
59
+ assets = r.json().get('assets', [])
60
+ zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
61
+ if not zip_asset: return
62
+ zip_url = zip_asset['browser_download_url']
63
+ zip_name = zip_asset['name']
64
+ with requests.get(zip_url, stream=True) as resp:
65
+ total = int(resp.headers.get('Content-Length', 0))
66
+ with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
67
+ for chunk in resp.iter_content(chunk_size=8192):
68
+ f.write(chunk)
69
+ pbar.update(len(chunk))
70
+ with zipfile.ZipFile(zip_name) as z:
71
+ for f in z.namelist():
72
+ if f.endswith(tuple(exes)) and '/bin/' in f:
73
+ z.extract(f)
74
+ os.rename(f, os.path.basename(f))
75
+ os.remove(zip_name)
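# Illustrative usage sketch (an assumption, not taken from this commit): download_ffmpeg()
# returns immediately on non-Windows systems or when ffmpeg.exe / ffprobe.exe / ffplay.exe
# already sit next to the script, so it can be called once at startup before any muxing.
import shutil
download_ffmpeg()
if shutil.which("ffmpeg") is None and not os.path.exists("ffmpeg.exe"):
    print("Warning: ffmpeg not found, generated videos cannot be muxed with an audio track")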
 
 
76
 
77
  def format_time(seconds):
78
  if seconds < 60:
 
166
  resolution = inputs["resolution"]
167
  width, height = resolution.split("x")
168
  width, height = int(width), int(height)
169
+ # if test_class_i2v(model_filename):
170
+ # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
171
+ # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
172
+ # return
173
+ # resolution = str(width) + "*" + str(height)
174
+ # if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
175
+ # gr.Info(f"Resolution {resolution} not supported by image 2 video")
176
+ # return
177
 
178
  if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
179
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
 
531
  task_id_s = task.get('id', f"task_{task_index}")
532
 
533
  image_keys = ["image_start", "image_end", "image_refs"]
534
+ video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
535
 
536
  for key in image_keys:
537
  images_pil = params_copy.get(key)
 
705
  params['state'] = state
706
 
707
  image_keys = ["image_start", "image_end", "image_refs"]
708
+ video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
709
 
710
  loaded_pil_images = {}
711
  loaded_video_paths = {}
 
923
  task_id_s = task.get('id', f"task_{task_index}")
924
 
925
  image_keys = ["image_start", "image_end", "image_refs"]
926
+ video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
927
 
928
  for key in image_keys:
929
  images_pil = params_copy.get(key)
 
1416
  text = reader.read()
1417
  server_config = json.loads(text)
1418
 
1419
+ # Deprecated models
1420
+ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors",
1421
+ "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors",
1422
+ "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
1423
+ ]:
1424
+ if Path(os.path.join("ckpts" , path)).is_file():
1425
+ os.remove( os.path.join("ckpts" , path))
1426
 
 
 
 
1427
 
1428
  transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
1429
  "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
1430
+ "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
1431
  "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
1432
+ transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
1433
+ "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
1434
+ "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
1435
+ "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
1436
  transformer_choices = transformer_choices_t2v + transformer_choices_i2v
1437
+ def get_dependent_models(model_filename, quantization ):
1438
+ if "fantasy" in model_filename:
1439
+ return [get_model_filename("i2v_720p", quantization)]
1440
+ else:
1441
+ return []
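# Illustrative example (comment only, exact values are assumptions): the Fantasy Speaking
# checkpoint reuses the 720p image2video transformer, so for instance
#   get_dependent_models("ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors", "int8")
# returns the i2v_720p checkpoint selected by get_model_filename for that quantization,
# while every other model yields an empty list.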
1442
+ model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"]
1443
  model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
1444
  "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
1445
  "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
1446
  "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
1447
+ "phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" }
1448
 
1449
 
1450
  def get_model_type(model_filename):
 
1454
  raise Exception("Unknown model:" + model_filename)
1455
 
1456
  def test_class_i2v(model_filename):
1457
+ return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename
1458
 
1459
  def get_model_name(model_filename, description_container = [""]):
1460
  if "Fun" in model_filename:
 
1492
  model_name = "Wan2.1 Phantom"
1493
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1494
  description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
1495
+ elif "fantasy" in model_filename:
1496
+ model_name = "Wan2.1 Fantasy Speaking 720p"
1497
+ model_name += " 14B" if "14B" in model_filename else " 1.3B"
1498
+ description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
1499
  else:
1500
  model_name = "Wan2.1 text2video"
1501
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
 
1541
  "repeat_generation": 1,
1542
  "multi_images_gen_type": 0,
1543
  "guidance_scale": 5.0,
1544
+ "audio_guidance_scale": 5.0,
1545
  "flow_shift": get_default_flow(filename, i2v),
1546
  "negative_prompt": "",
1547
  "activated_loras": [],
 
1725
 
1726
  from huggingface_hub import hf_hub_download, snapshot_download
1727
  repoId = "DeepBeepMeep/Wan2.1"
1728
+ sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ]
1729
+ fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
1730
+ ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1731
  targetRoot = "ckpts/"
1732
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1733
  if len(files)==0:
 
1841
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
1842
 
1843
 
1844
+ def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1845
 
1846
  cfg = WAN_CONFIGS['t2v-14B']
1847
+ filename = model_filename[-1]
1848
  # cfg = WAN_CONFIGS['t2v-1.3B']
1849
+ print(f"Loading '{filename}' model...")
1850
+ if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
1851
  model_factory = wan.DTT2V
1852
  else:
1853
  model_factory = wan.WanT2V
 
1867
 
1868
  return wan_model, pipe
1869
 
1870
+ def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1871
 
1872
+ filename = model_filename[-1]
1873
+ print(f"Loading '{filename}' model...")
1874
 
1875
  cfg = WAN_CONFIGS['i2v-14B']
1876
  wan_model = wan.WanI2V(
 
1892
  def load_models(model_filename):
1893
  global transformer_filename
1894
 
 
1895
  perc_reserved_mem_max = args.perc_reserved_mem_max
1896
 
1897
  major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
 
1900
  default_dtype = torch.float16
1901
  else:
1902
  default_dtype = torch.float16 if args.fp16 else torch.bfloat16
1903
+ model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
1904
+ updated_model_filename = []
1905
+ for filename in model_filelist:
1906
+ if default_dtype == torch.float16 :
1907
+ if "quanto_int8" in filename:
1908
+ filename = filename.replace("quanto_int8", "quanto_fp16_int8")
1909
+ elif "quanto_mbf16_int8":
1910
+ filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
1911
+ updated_model_filename.append(filename)
1912
+ download_models(filename, text_encoder_filename)
1913
+ model_filelist = updated_model_filename
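# Worked example of the fp16 remapping above (illustrative): when default_dtype is
# torch.float16 (e.g. --fp16 or a GPU without bf16 support), a checkpoint name such as
#   "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors"
# is rewritten to
#   "ckpts/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors"
# before download_models() is called, so the fp16 variant of the quantized weights is fetched.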
1914
  VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
1915
  mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
1916
+ transformer_filename = None
1917
+ new_transformer_filename = model_filelist[-1]
1918
+ if test_class_i2v(new_transformer_filename):
1919
+ wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1920
  else:
1921
+ wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1922
+ wan_model._model_file_name = new_transformer_filename
1923
+ kwargs = { "extraModelsToQuantize": None}
1924
  if profile == 2 or profile == 4:
1925
  kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
1926
  # if profile == 4:
 
1930
  offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
1931
  if len(args.gpu) > 0:
1932
  torch.set_default_device(args.gpu)
1933
+ transformer_filename = new_transformer_filename
1934
  return wan_model, offloadobj, pipe["transformer"]
1935
 
1936
  if not "P" in preload_model_policy:
 
2049
  preload_model_policy = server_config["preload_model_policy"]
2050
  transformer_quantization = server_config["transformer_quantization"]
2051
  transformer_types = server_config["transformer_types"]
 
 
2052
 
 
 
 
 
2053
  if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
2054
  model_choice = gr.Dropdown()
2055
  else:
 
2259
  frame_height, frame_width, _ = frames_list[0].shape
2260
 
2261
  if fit_canvas :
2262
+ scale1 = min(height / frame_height, width / frame_width)
2263
+ scale2 = min(height / frame_width, width / frame_height)
2264
+ scale = max(scale1, scale2)
2265
  else:
2266
  scale = ((height * width ) / (frame_height * frame_width))**(1/2)
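# Worked example for the fit_canvas branch above (illustrative): with a requested
# 832x480 canvas (width=832, height=480) and a portrait 480x832 input frame,
#   scale1 = min(480/832, 832/480) = 0.577
#   scale2 = min(480/480, 832/832) = 1.0
# so scale = max(scale1, scale2) = 1.0: the frame keeps its native size and the canvas
# is effectively swapped to the input orientation instead of shrinking the video.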
2267
 
 
2368
  seed,
2369
  num_inference_steps,
2370
  guidance_scale,
2371
+ audio_guidance_scale,
2372
  flow_shift,
2373
  embedded_guidance_scale,
2374
  repeat_generation,
 
2388
  video_guide,
2389
  keep_frames_video_guide,
2390
  video_mask,
2391
+ audio_guide,
2392
  sliding_window_size,
2393
  sliding_window_overlap,
2394
+ sliding_window_overlap_noise,
2395
  sliding_window_discard_last_frames,
2396
  remove_background_image_ref,
2397
  temporal_upsampling,
 
2523
  # VAE Tiling
2524
  device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
2525
 
2526
+ diffusion_forcing = "diffusion_forcing" in model_filename
2527
+ vace = "Vace" in model_filename
2528
+ if diffusion_forcing:
2529
+ fps = 24
2530
+ elif audio_guide != None:
2531
+ fps = 23
2532
+ else:
2533
+ fps = 16
2534
+
2535
  joint_pass = boost ==1 #and profile != 1 and profile != 3
2536
  # TeaCache
2537
  trans.enable_teacache = tea_cache_setting > 0
 
2541
  trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
2542
 
2543
  if image2video:
2544
+ if '720p' in model_filename:
 
 
2545
  trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
2546
  else:
2547
+ trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
2548
  else:
2549
  if '1.3B' in model_filename:
2550
  trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
 
2557
  if "recam" in model_filename:
2558
  source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
2559
  target_camera = model_mode
2560
+
2561
+ audio_proj_split = None
2562
+ audio_scale = None
2563
+ audio_context_lens = None
2564
+ if audio_guide != None:
2565
+ from fantasytalking.infer import parse_audio
2566
+ import librosa
2567
+ duration = librosa.get_duration(path=audio_guide)
2568
+ video_length = min(int(fps * duration // 4) * 4 + 5, video_length)
2569
+ audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device )
2570
+ audio_scale = 1.0
2571
+
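# Worked example of the audio-driven length above (illustrative): with fps = 23 and a
# 10 s voice track, int(23 * 10 // 4) * 4 + 5 = 57 * 4 + 5 = 233 frames, so video_length
# is capped at min(233, video_length) and the clip stays roughly as long as the audio.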
2572
  import random
2573
  if seed == None or seed <0:
2574
  seed = random.randint(0, 999999999)
 
2585
  extra_generation = 0
2586
  initial_total_windows = 0
2587
  max_frames_to_generate = video_length
 
 
2588
  phantom = "phantom" in model_filename
2589
  if diffusion_forcing or vace:
2590
  reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
2591
+ else:
2592
+ reuse_frames = 0
2593
  if diffusion_forcing and source_video != None:
2594
  video_length += sliding_window_overlap
2595
  sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
 
2605
  initial_total_windows = 1
2606
 
2607
  first_window_video_length = video_length
 
2608
 
2609
  gen["sliding_window"] = sliding_window
 
2610
  while not abort:
2611
  extra_generation += gen.get("extra_orders",0)
2612
  gen["extra_orders"] = 0
 
2626
  guide_start_frame = 0
2627
  video_length = first_window_video_length
2628
  gen["extra_windows"] = 0
2629
+ start_time = time.time()
2630
  while not abort:
2631
  if sliding_window:
2632
  prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
 
2675
 
2676
  if preprocess_type != None :
2677
  send_cmd("progress", progress_args)
2678
+ video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
2679
  keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
2680
  if len(error) > 0:
2681
  raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
 
2711
  trans.teacache_counter = 0
2712
  trans.num_steps = num_inference_steps
2713
  trans.teacache_skipped_steps = 0
2714
+ trans.previous_residual = None
 
2715
 
2716
  if image2video:
2717
  samples = wan_model.generate(
 
2719
  image_start,
2720
  image_end if image_end != None else None,
2721
  frame_num=(video_length // 4)* 4 + 1,
2722
+ # max_area=MAX_AREA_CONFIGS[resolution_reformated],
2723
+ height = height,
2724
+ width = width,
2725
  shift=flow_shift,
2726
  sampling_steps=num_inference_steps,
2727
  guide_scale=guidance_scale,
 
2736
  slg_end = slg_end_perc/100,
2737
  cfg_star_switch = cfg_star_switch,
2738
  cfg_zero_step = cfg_zero_step,
2739
+ add_frames_for_end_image = "image2video" in model_filename,
2740
+ audio_cfg_scale= audio_guidance_scale,
2741
+ audio_proj= audio_proj_split,
2742
+ audio_scale= audio_scale,
2743
+ audio_context_lens= audio_context_lens
2744
  )
2745
  elif diffusion_forcing:
2746
  samples = wan_model.generate(
 
2758
  callback= callback,
2759
  VAE_tile_size = VAE_tile_size,
2760
  joint_pass = joint_pass,
2761
+ slg_layers = slg_layers,
2762
+ slg_start = slg_start_perc/100,
2763
+ slg_end = slg_end_perc/100,
2764
+ addnoise_condition = sliding_window_overlap_noise,
2765
  ar_step = model_mode, #5
2766
  causal_block_size = 5,
2767
  causal_attention = True,
2768
  fps = fps,
2769
  )
2770
  else:
2771
+ samples = wan_model.generate(
2772
  prompt,
2773
  input_frames = src_video,
2774
  input_ref_images= src_ref_images,
 
2792
  slg_end = slg_end_perc/100,
2793
  cfg_star_switch = cfg_star_switch,
2794
  cfg_zero_step = cfg_zero_step,
2795
+ overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1),
2796
+ overlap_noise = sliding_window_overlap_noise,
2797
+ vace = vace
2798
  )
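# Worked example for overlapped_latents above (illustrative): with reuse_frames = 5 and
# window_no > 1, (5 - 1) // 4 + 1 = 2 latent frames are passed as overlap, matching the
# VAE's packing of 4 video frames per latent frame (plus the leading frame).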
2799
  except Exception as e:
2800
  if temp_filename!= None and os.path.isfile(temp_filename):
 
2826
  print('\n'.join(tb))
2827
  send_cmd("error", new_error)
2828
  return
2829
+ finally:
2830
+ trans.previous_residual = None
2831
 
2832
  if trans.enable_teacache:
2833
+ print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
 
 
2834
 
2835
  if samples != None:
2836
  samples = samples.to("cpu")
 
2854
  if discard_last_frames > 0:
2855
  sample = sample[: , :-discard_last_frames]
2856
  guide_start_frame -= discard_last_frames
2857
+ if reuse_frames == 0:
2858
+ pre_video_guide = sample[:,9999 :]
2859
+ else:
2860
+ # noise_factor = 200/ 1000
2861
+ # pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor
2862
+ pre_video_guide = sample[:, -reuse_frames:]
2863
+
2864
+
2865
  if prefix_video != None:
2866
+ if reuse_frames == 0:
2867
+ sample = torch.cat([ prefix_video[:, :], sample], dim = 1)
2868
+ else:
2869
+ sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
2870
  prefix_video = None
2871
  if sliding_window and window_no > 1:
2872
+ if reuse_frames == 0:
2873
+ sample = sample[: , :]
2874
+ else:
2875
+ sample = sample[: , reuse_frames:]
2876
 
2877
+ guide_start_frame -= reuse_frames
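# Worked example of the window stitching above (illustrative, assuming an 81-frame window,
# discard_last_frames = 4 and reuse_frames = 5): each window keeps 81 - 4 = 77 frames, its
# last 5 frames seed the next window as pre_video_guide, and those 5 frames are then cut
# from the start of the next window's output, so every additional window adds 72 new frames.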
2878
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
2879
  if os.name == 'nt':
2880
  file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
 
2932
  sample = torch.cat([frames_already_processed, sample], dim=1)
2933
  frames_already_processed = sample
2934
 
2935
+ if audio_guide == None:
2936
+ cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
2937
+ else:
2938
+ save_path_tmp = video_path[:-4] + "_tmp.mp4"
2939
+ cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
2940
+ final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ]
2941
+ import subprocess
2942
+ subprocess.run(final_command, check=True)
2943
+ os.remove(save_path_tmp)
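# Note on the ffmpeg command above: it re-encodes the temporary clip with libx264, encodes
# the voice track to AAC and muxes the two, while -shortest trims the output to the shorter
# stream so any trailing audio or video is dropped before the temporary file is deleted.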
2944
+
2945
+ end_time = time.time()
2946
 
2947
  inputs = get_function_arguments(generate_video, locals())
2948
  inputs.pop("send_cmd")
2949
+ inputs.pop("task_id")
2950
  configs = prepare_inputs_dict("metadata", inputs)
2951
+ configs["generation_time"] = round(end_time-start_time)
2952
  metadata_choice = server_config.get("metadata_type","metadata")
2953
  if metadata_choice == "json":
2954
  with open(video_path.replace('.mp4', '.json'), 'w') as f:
 
3175
  prompt_no = gen["prompt_no"]
3176
  prompts_max = gen.get("prompts_max",0)
3177
  total_generation = gen.get("total_generation", 1)
3178
+ repeat_no = gen.get("repeat_no",0)
3179
  total_generation += gen.get("extra_orders", 0)
3180
  total_windows = gen.get("total_windows", 0)
3181
  total_windows += gen.get("extra_windows", 0)
 
3518
 
3519
  if target == "state":
3520
  return inputs
3521
+ unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"]
3522
  for k in unsaved_params:
3523
  inputs.pop(k)
3524
 
 
3546
  inputs.pop(k)
3547
 
3548
  if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
3549
+ unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
3550
  for k in unsaved_params:
3551
  inputs.pop(k)
3552
 
3553
+ if not "fantasy" in model_filename:
3554
+ inputs.pop("audio_guidance_scale")
3555
+
3556
+
3557
  if target == "metadata":
3558
  inputs = {k: v for k,v in inputs.items() if v != None }
3559
 
 
3577
  seed,
3578
  num_inference_steps,
3579
  guidance_scale,
3580
+ audio_guidance_scale,
3581
  flow_shift,
3582
  embedded_guidance_scale,
3583
  repeat_generation,
 
3597
  video_guide,
3598
  keep_frames_video_guide,
3599
  video_mask,
3600
+ audio_guide,
3601
  sliding_window_size,
3602
  sliding_window_overlap,
3603
+ sliding_window_overlap_noise,
3604
  sliding_window_discard_last_frames,
3605
  remove_background_image_ref,
3606
  temporal_upsampling,
 
3903
  recammaster = "recam" in model_filename
3904
  vace = "Vace" in model_filename
3905
  phantom = "phantom" in model_filename
3906
+ fantasy = "fantasy" in model_filename
3907
  with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
3908
  if diffusion_forcing:
3909
  image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
 
4009
 
4010
 
4011
  video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
4012
+ audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy )
4013
 
4014
  advanced_prompt = advanced_ui
4015
  prompt_vars=[]
 
4042
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
4043
  wizard_variables_var = gr.Text(wizard_variables, visible = False)
4044
  with gr.Row():
4045
+ if test_class_i2v(model_filename) and False:
4046
  resolution = gr.Dropdown(
4047
  choices=[
4048
  # 720p
4049
+ ("720p (same amount of pixels)", "1280x720"),
4050
+ ("480p (same amount of pixels)", "832x480"),
4051
  ],
4052
  value=ui_defaults.get("resolution","480p"),
4053
  label="Resolution (video will have the same height / width ratio than the original image)"
 
4059
  ("1280x720 (16:9, 720p)", "1280x720"),
4060
  ("720x1280 (9:16, 720p)", "720x1280"),
4061
  ("1024x1024 (4:3, 720p)", "1024x024"),
4062
+ ("832x1104 (3:4, 720p)", "832x1104"),
4063
+ ("1104x832 (3:4, 720p)", "1104x832"),
4064
+ ("960x960 (1:1, 720p)", "960x960"),
4065
  # 480p
4066
  ("960x544 (16:9, 540p)", "960x544"),
4067
  ("544x960 (16:9, 540p)", "544x960"),
4068
  ("832x480 (16:9, 480p)", "832x480"),
4069
  ("480x832 (9:16, 480p)", "480x832"),
4070
+ ("832x624 (4:3, 480p)", "832x624"),
4071
+ ("624x832 (3:4, 480p)", "624x832"),
4072
+ ("720x720 (1:1, 480p)", "720x720"),
4073
+ ("512x512 (1:1, 480p)", "512x512"),
4074
  ],
4075
  value=ui_defaults.get("resolution","832x480"),
4076
+ label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
4077
  )
4078
  with gr.Row():
4079
  if recammaster:
 
4082
  video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
4083
  elif vace:
4084
  video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4085
+ elif fantasy:
4086
+ video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
4087
  else:
4088
  video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4089
  with gr.Row():
 
4103
  choices=[
4104
  ("Generate every combination of images and texts", 0),
4105
  ("Match images and text prompts", 1),
4106
+ ], visible= test_class_i2v(model_filename), label= "Multiple Images as Text Prompts"
4107
  )
4108
  with gr.Row():
4109
  guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
4110
+ audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
4111
  embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
4112
  flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
4113
  with gr.Row():
 
4174
 
4175
  with gr.Tab("Quality"):
4176
  with gr.Row():
4177
+ gr.Markdown("<B>Skip Layer Guidance (improves video quality)</B>")
4178
  with gr.Row():
4179
  slg_switch = gr.Dropdown(
4180
  choices=[
 
4223
  if diffusion_forcing:
4224
  sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
4225
  sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4226
+ sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
4227
+ sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
4228
  else:
4229
  sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4230
+ sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4231
+ sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
4232
+ sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
4233
 
4234
 
4235
  with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
 
4244
  label="RIFLEx positional embedding to generate long video"
4245
  )
4246
 
4247
+ with gr.Row():
4248
+ save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
4249
 
4250
  if not update_form:
4251
  with gr.Column():
 
5112
  theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
5113
 
5114
  with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
5115
+ gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.5 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
5116
  global model_list
5117
 
5118
  tab_state = gr.State({ "tab_no":0 })
 
5153
 
5154
  if __name__ == "__main__":
5155
  atexit.register(autosave_queue)
5156
+ download_ffmpeg()
5157
  # threading.Thread(target=runner, daemon=True).start()
5158
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
5159
  server_port = int(args.server_port)