Spaces · DeepBeepMeep · Running on T4
Commit 453823e · Parent(s): 068831e
Added support for fantasyspeaking model
Files changed:
- README.md +8 -1
- fantasytalking/infer.py +27 -0
- fantasytalking/model.py +162 -0
- fantasytalking/utils.py +52 -0
- requirements.txt +2 -1
- wan/configs/__init__.py +2 -0
- wan/diffusion_forcing.py +36 -20
- wan/image2video.py +99 -69
- wan/modules/attention.py +69 -81
- wan/modules/model.py +139 -104
- wan/modules/sage2_core.py +1 -1
- wan/text2video.py +53 -33
- wgp.py +210 -133
README.md
CHANGED
@@ -10,6 +10,7 @@
 
 ## 🔥 Latest News!!
+* May 5 2025: 👋 Wan 2.1GP v4.5: FantasySpeaking model: you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos (see recommended settings), plus new high-quality processing options (mixed 16/32-bit calculation and a 32-bit VAE).
 * April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model to transfer people or objects into a video; it works quite well at 720p with more than 30 steps.
 * April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high-quality "infinite length" videos (see the Sliding Window section below). Note that SkyReels uses causal attention, which is only supported by SDPA attention, so even if you choose another attention type, some of the processing will still use SDPA attention.
 
@@ -303,7 +304,13 @@ Vace provides on its github (https://github.com/ali-vilab/VACE/tree/main/vace/gr
 
 There is also a guide that describes the various combinations of hints (https://github.com/ali-vilab/VACE/blob/main/UserGuide.md). Good luck!
 
-It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
+It seems you will get better results with Vace if you turn on "Skip Layer Guidance" with its default configuration.
+
+Other recommended settings for Vace:
+- Use a long prompt description, especially for the people / objects that are in the background and not in reference images. This will ensure consistency between the windows.
+- Set a medium-sized overlap window: long enough to give the model a sense of the motion, but short enough that any blurred overlapped frames do not turn the rest of the video into a blurred video.
+- Truncate at least the last 4 frames of each generated window, as Vace's last frames tend to be blurry.
+
 
 ### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
 With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple videos to form a very long video (up to 1 min).
fantasytalking/infer.py
ADDED
@@ -0,0 +1,27 @@
# Copyright Alibaba Inc. All Rights Reserved.

from transformers import Wav2Vec2Model, Wav2Vec2Processor

from .model import FantasyTalkingAudioConditionModel
from .utils import get_audio_features


def parse_audio(audio_path, num_frames, fps=23, device="cuda"):
    fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
    from mmgp import offload
    from accelerate import init_empty_weights
    from fantasytalking.model import AudioProjModel
    # Load the audio projection model weights without allocating them twice
    with init_empty_weights():
        proj_model = AudioProjModel(768, 2048)
    offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
    proj_model.to(device).eval().requires_grad_(False)

    # Extract wav2vec features aligned with the requested number of video frames
    wav2vec_model_dir = "ckpts/wav2vec"
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
    audio_wav2vec_fea = get_audio_features(wav2vec, wav2vec_processor, audio_path, fps, num_frames)

    audio_proj_fea = proj_model(audio_wav2vec_fea)
    pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea, pos_idx_ranges, expand_length=4)  # [b, 21, 9+8, 768]
    return audio_proj_split, audio_context_lens
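
Not part of the commit: a minimal usage sketch of parse_audio, assuming the ckpts/ folder above has been populated and a CUDA device is available. The audio file name is a placeholder.

# Illustration only, not committed code.
from fantasytalking.infer import parse_audio

num_frames = 81  # same default frame count as the Wan pipelines
audio_proj_split, audio_context_lens = parse_audio(
    "driving_voice.wav", num_frames=num_frames, fps=23, device="cuda"
)
# audio_proj_split: one audio feature slice per latent frame, roughly [1, 21, 9+8, 768]
# audio_context_lens: actual (unpadded) length of each slice, later passed as k_lens to attention
print(audio_proj_split.shape, audio_context_lens.shape)
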
fantasytalking/model.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from wan.modules.attention import pay_attention


class AudioProjModel(nn.Module):
    def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, audio_embeds):
        context_tokens = self.proj(audio_embeds)
        context_tokens = self.norm(context_tokens)
        return context_tokens  # [B, L, C]


class WanCrossAttentionProcessor(nn.Module):
    def __init__(self, context_dim, hidden_dim):
        super().__init__()

        self.context_dim = context_dim
        self.hidden_dim = hidden_dim

        self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)

        nn.init.zeros_(self.k_proj.weight)
        nn.init.zeros_(self.v_proj.weight)

    def __call__(
        self,
        q: torch.Tensor,
        audio_proj: torch.Tensor,
        latents_num_frames: int = 21,
        audio_context_lens=None,
    ) -> torch.Tensor:
        """
        audio_proj: [B, 21, L3, C]
        audio_context_lens: [B*21].
        """
        b, l, n, d = q.shape

        if len(audio_proj.shape) == 4:
            audio_q = q.view(b * latents_num_frames, -1, n, d)  # [b*21, l1, n, d]
            ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
            qkv_list = [audio_q, ip_key, ip_value]
            del q, audio_q, ip_key, ip_value
            audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)  # audio_context_lens
            audio_x = audio_x.view(b, l, n, d)
            audio_x = audio_x.flatten(2)
        elif len(audio_proj.shape) == 3:
            ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
            ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
            qkv_list = [q, ip_key, ip_value]
            del q, ip_key, ip_value
            audio_x = pay_attention(qkv_list, k_lens=audio_context_lens)  # audio_context_lens
            audio_x = audio_x.flatten(2)
        return audio_x


class FantasyTalkingAudioConditionModel(nn.Module):
    def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
        super().__init__()

        self.audio_in_dim = audio_in_dim
        self.audio_proj_dim = audio_proj_dim

    def split_audio_sequence(self, audio_proj_length, num_frames=81):
        """
        Map the audio feature sequence to corresponding latent frame slices.

        Args:
            audio_proj_length (int): The total length of the audio feature sequence
                (e.g., 173 in audio_proj[1, 173, 768]).
            num_frames (int): The number of video frames in the training data (default: 81).

        Returns:
            list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
                (within the audio feature sequence) corresponding to a latent frame.
        """
        # Average number of tokens per original video frame
        tokens_per_frame = audio_proj_length / num_frames

        # Each latent frame covers 4 video frames, and we want the center
        tokens_per_latent_frame = tokens_per_frame * 4
        half_tokens = int(tokens_per_latent_frame / 2)

        pos_indices = []
        for i in range(int((num_frames - 1) / 4) + 1):
            if i == 0:
                pos_indices.append(0)
            else:
                start_token = tokens_per_frame * ((i - 1) * 4 + 1)
                end_token = tokens_per_frame * (i * 4 + 1)
                center_token = int((start_token + end_token) / 2) - 1
                pos_indices.append(center_token)

        # Build index ranges centered around each position
        pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]

        # Adjust the first range to avoid negative start index
        pos_idx_ranges[0] = [
            -(half_tokens * 2 - pos_idx_ranges[1][0]),
            pos_idx_ranges[1][0],
        ]

        return pos_idx_ranges

    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
        """
        Split the input tensor into subsequences based on index ranges, and apply right-side
        zero-padding if the range exceeds the input boundaries.

        Args:
            input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
            pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
            expand_length (int): Number of tokens to expand on both sides of each subsequence.

        Returns:
            sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
                Each element is a padded subsequence.
            k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each
                subsequence. Useful for ignoring padding tokens in attention masks.
        """
        pos_idx_ranges = [
            [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
        ]
        sub_sequences = []
        seq_len = input_tensor.size(1)  # 173
        max_valid_idx = seq_len - 1  # 172
        k_lens_list = []
        for start, end in pos_idx_ranges:
            # Calculate the fill amount
            pad_front = max(-start, 0)
            pad_back = max(end - max_valid_idx, 0)

            # Calculate the start and end indices of the valid part
            valid_start = max(start, 0)
            valid_end = min(end, max_valid_idx)

            # Extract the valid part
            if valid_start <= valid_end:
                valid_part = input_tensor[:, valid_start : valid_end + 1, :]
            else:
                valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))

            # Perform padding in the sequence dimension (the 1st dimension)
            padded_subseq = F.pad(
                valid_part,
                (0, 0, 0, pad_back + pad_front, 0, 0),
                mode="constant",
                value=0,
            )
            k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)

            sub_sequences.append(padded_subseq)
        return torch.stack(sub_sequences, dim=1), torch.tensor(
            k_lens_list, dtype=torch.long
        )
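
Not part of the commit: a small sketch of what the two helpers above return, using the 173-token length from the docstring example and the 81-frame default (an assumed, made-up feature tensor).

# Illustration only, not committed code.
import torch
from fantasytalking.model import FantasyTalkingAudioConditionModel

model = FantasyTalkingAudioConditionModel(None, 768, 2048)
audio_proj_fea = torch.randn(1, 173, 768)   # dummy stand-in for the projected wav2vec features

ranges = model.split_audio_sequence(173, num_frames=81)
print(len(ranges))                          # 21 [start, end] token ranges, one per latent frame

slices, k_lens = model.split_tensor_with_padding(audio_proj_fea, ranges, expand_length=4)
print(slices.shape, k_lens.shape)           # roughly [1, 21, 17, 768] and [21] for these defaults;
                                            # k_lens holds the unpadded length of each slice
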
fantasytalking/utils.py
ADDED
@@ -0,0 +1,52 @@
# Copyright Alibaba Inc. All Rights Reserved.

import imageio
import librosa
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


def resize_image_by_longest_edge(image_path, target_size):
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    scale = target_size / max(width, height)
    new_size = (int(width * scale), int(height * scale))
    return image.resize(new_size, Image.LANCZOS)


def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    writer = imageio.get_writer(
        save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
    )
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()


def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
    sr = 16000
    audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # resampled to 16 kHz

    start_time = 0
    # end_time = (0 + (num_frames - 1) * 1) / fps
    end_time = num_frames / fps

    start_sample = int(start_time * sr)
    end_sample = int(end_time * sr)

    try:
        audio_segment = audio_input[start_sample:end_sample]
    except Exception:
        audio_segment = audio_input

    input_values = audio_processor(
        audio_segment, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values.to("cuda")

    with torch.no_grad():
        fea = wav2vec(input_values).last_hidden_state

    return fea
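
Not part of the commit: the clip is resampled to 16 kHz and Wav2Vec2's convolutional front end yields roughly 50 feature frames per second (about one feature per 320 samples), so the feature length scales with num_frames / fps. A quick back-of-the-envelope check under that assumption, using the 81-frame / 23-fps defaults seen elsewhere in this commit:

# Illustration only, not committed code.
num_frames, fps, sr = 81, 23, 16000
audio_samples = int(num_frames / fps * sr)   # about 56347 samples of audio
feature_len = audio_samples // 320           # about 176 wav2vec features
print(audio_samples, feature_len)            # comparable to the 173-token example in model.py
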
requirements.txt
CHANGED
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.4.
+mmgp==3.4.2
 peft==0.14.0
 mutagen
 pydantic==2.10.6
@@ -28,4 +28,5 @@ timm
 segment-anything
 omegaconf
 hydra-core
+librosa
 # rembg==2.0.65
wan/configs/__init__.py
CHANGED
@@ -44,6 +44,8 @@ SUPPORTED_SIZES = {
 VACE_SIZE_CONFIGS = {
     '480*832': (480, 832),
     '832*480': (832, 480),
+    '720*1280': (720, 1280),
+    '1280*720': (1280, 720),
 }
 
 VACE_MAX_AREA_CONFIGS = {
wan/diffusion_forcing.py
CHANGED
@@ -56,16 +56,18 @@ class DTT2V:
             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
             device=self.device)
 
-        logging.info(f"Creating WanModel from {model_filename}")
+        logging.info(f"Creating WanModel from {model_filename[-1]}")
        from mmgp import offload
        # model_filename = "model.safetensors"
-
+        # model_filename = "c:/temp/diffusion_pytorch_model-00001-of-00006.safetensors"
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) # , forcedConfigPath="c:/temp/config _df720.json")
        # offload.load_model_data(self.model, "recam.ckpt")
        # self.model.cpu()
-
+        # dtype = torch.float16
+        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
        offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
-        # offload.save_model(self.model, "
+        # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", do_quantize= True, config_file_path="c:/temp/config _df720.json")
        # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
 
        self.model.eval().requires_grad_(False)
@@ -200,6 +202,9 @@ class DTT2V:
        fps: int = 24,
        VAE_tile_size = 0,
        joint_pass = False,
+        slg_layers = None,
+        slg_start = 0.0,
+        slg_end = 1.0,
        callback = None,
    ):
        self._interrupt = False
@@ -211,6 +216,7 @@ class DTT2V:
 
        if ar_step == 0:
            causal_block_size = 1
+            causal_attention = False
 
        i2v_extra_kwrags = {}
        prefix_video = None
@@ -252,31 +258,33 @@ class DTT2V:
                prefix_video = output_video.to(self.device)
            else:
                causal_block_size = 1
+                causal_attention = False
                ar_step = 0
                prefix_video = image
                prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
                if prefix_video.dtype == torch.uint8:
                    prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
                prefix_video = prefix_video.to(self.device)
-                prefix_video =
-                predix_video_latent_length = prefix_video
+                prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)]
+                predix_video_latent_length = prefix_video.shape[1]
                truncate_len = predix_video_latent_length % causal_block_size
                if truncate_len != 0:
                    if truncate_len == predix_video_latent_length:
                        causal_block_size = 1
+                        causal_attention = False
+                        ar_step = 0
                    else:
                        print("the length of prefix video is truncated for the casual block size alignment.")
                        predix_video_latent_length -= truncate_len
-                        prefix_video
+                        prefix_video = prefix_video[:, : predix_video_latent_length]
 
        base_num_frames_iter = latent_length
        latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
        latents = self.prepare_latents(
            latent_shape, dtype=torch.float32, device=self.device, generator=generator
        )
-        latents = [latents]
        if prefix_video is not None:
-            latents[
+            latents[:, :predix_video_latent_length] = prefix_video.to(torch.float32)
        step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
            base_num_frames_iter,
            init_timesteps,
@@ -298,6 +306,8 @@ class DTT2V:
        if callback != None:
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)
        if self.model.enable_teacache:
+            x_count = 2 if self.do_classifier_free_guidance else 1
+            self.model.previous_residual = [None] * x_count
            time_steps_comb = []
            self.model.num_steps = updated_num_steps
            for i, timestep_i in enumerate(step_matrix):
@@ -309,7 +319,7 @@ class DTT2V:
                self.model.compute_teacache_threshold(self.model.teacache_start_step, time_steps_comb, self.model.teacache_multiplier)
            del time_steps_comb
        from mmgp import offload
-        freqs = get_rotary_pos_embed(latents
+        freqs = get_rotary_pos_embed(latents.shape[1 :], enable_RIFLEx= False)
        kwrags = {
            "freqs" :freqs,
            "fps" : fps_embeds,
@@ -320,27 +330,27 @@ class DTT2V:
        }
        kwrags.update(i2v_extra_kwrags)
 
-
        for i, timestep_i in enumerate(tqdm(step_matrix)):
+            kwrags["slg_layers"] = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
+
            offload.set_step_no_for_lora(self.model, i)
            update_mask_i = step_update_mask[i]
            valid_interval_start, valid_interval_end = valid_interval[i]
            timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
-            latent_model_input =
+            latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].clone()
            if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
                noise_factor = 0.001 * addnoise_condition
                timestep_for_noised_condition = addnoise_condition
-                latent_model_input[
-                latent_model_input[
+                latent_model_input[:, valid_interval_start:predix_video_latent_length] = (
+                    latent_model_input[:, valid_interval_start:predix_video_latent_length]
                    * (1.0 - noise_factor)
                    + torch.randn_like(
-                latent_model_input[
+                        latent_model_input[:, valid_interval_start:predix_video_latent_length]
                    )
                    * noise_factor
                )
                timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
            kwrags.update({
-                "x" : torch.stack([latent_model_input[0]]),
                "t" : timestep,
                "current_step" : i,
            })
@@ -349,6 +359,7 @@ class DTT2V:
            if True:
                if not self.do_classifier_free_guidance:
                    noise_pred = self.model(
+                        x=[latent_model_input],
                        context=[prompt_embeds],
                        **kwrags,
                    )[0]
@@ -358,6 +369,7 @@ class DTT2V:
                else:
                    if joint_pass:
                        noise_pred_cond, noise_pred_uncond = self.model(
+                            x=[latent_model_input, latent_model_input],
                            context= [prompt_embeds, negative_prompt_embeds],
                            **kwrags,
                        )
@@ -365,12 +377,16 @@ class DTT2V:
                            return None
                    else:
                        noise_pred_cond = self.model(
+                            x=[latent_model_input],
+                            x_id=0,
                            context=[prompt_embeds],
                            **kwrags,
                        )[0]
                        if self._interrupt:
                            return None
                        noise_pred_uncond = self.model(
+                            x=[latent_model_input],
+                            x_id=1,
                            context=[negative_prompt_embeds],
                            **kwrags,
                        )[0]
@@ -380,18 +396,18 @@ class DTT2V:
                del noise_pred_cond, noise_pred_uncond
            for idx in range(valid_interval_start, valid_interval_end):
                if update_mask_i[idx].item():
-                    latents[
+                    latents[:, idx] = sample_schedulers[idx].step(
                        noise_pred[:, idx - valid_interval_start],
                        timestep_i[idx],
-                    latents[
+                        latents[:, idx],
                        return_dict=False,
                        generator=generator,
                    )[0]
                    sample_schedulers_counter[idx] += 1
            if callback is not None:
-                callback(i, latents
+                callback(i, latents.squeeze(0), False)
 
-        x0 = latents
+        x0 = latents.unsqueeze(0)
        videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
        output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
        return output_video
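
Not part of the commit: the new slg_start / slg_end arguments gate Skip Layer Guidance to a fraction of the denoising schedule. A standalone sketch of that gating logic with made-up step count and layer indices:

# Illustration only, not committed code.
slg_layers = [9, 10]            # example layers to skip in the unconditional pass
slg_start, slg_end = 0.1, 0.9   # apply SLG between 10% and 90% of the schedule
updated_num_steps = 30

for i in range(updated_num_steps):
    active = slg_layers if int(slg_start * updated_num_steps) <= i < int(slg_end * updated_num_steps) else None
    # active is None for steps 0-2 and 27-29, and [9, 10] for steps 3-26
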
wan/image2video.py
CHANGED
@@ -8,7 +8,7 @@ import sys
 import types
 from contextlib import contextmanager
 from functools import partial
-
+import json
 import numpy as np
 import torch
 import torch.cuda.amp as amp
@@ -84,13 +84,29 @@ class WanI2V:
                                   config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
-        logging.info(f"Creating WanModel from {model_filename}")
+        logging.info(f"Creating WanModel from {model_filename[-1]}")
        from mmgp import offload
 
-
-
-
-        # offload.
+        # fantasy = torch.load("c:/temp/fantasy.ckpt")
+        # proj_model = fantasy["proj_model"]
+        # audio_processor = fantasy["audio_processor"]
+        # offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
+        # offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
+        # for k,v in audio_processor.items():
+        #     audio_processor[k] = v.to(torch.bfloat16)
+        # with open("fantasy_config.json", "r", encoding="utf-8") as reader:
+        #     config_text = reader.read()
+        # config_json = json.loads(config_text)
+        # offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
+        # model_filename = [model_filename, "audio_processor_bf16.safetensors"]
+        # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
+        # dtype = torch.float16
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
+        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
+        # offload.change_dtype(self.model, dtype, True)
+        # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
+        # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
+        # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
 
        # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
        self.model.eval().requires_grad_(False)
@@ -102,6 +118,8 @@ class WanI2V:
        input_prompt,
        img,
        img2 = None,
+        height = 720,
+        width = 1280,
        max_area=720 * 1280,
        frame_num=81,
        shift=5.0,
@@ -119,7 +137,11 @@ class WanI2V:
        slg_end = 1.0,
        cfg_star_switch = True,
        cfg_zero_step = 5,
-        add_frames_for_end_image = True
+        add_frames_for_end_image = True,
+        audio_scale=None,
+        audio_cfg_scale=None,
+        audio_proj=None,
+        audio_context_lens=None,
    ):
        r"""
        Generates video frames from input image and text prompt using diffusion process.
@@ -167,17 +189,25 @@ class WanI2V:
            frame_num +=1
        lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
 
        h, w = img.shape[1:]
-        aspect_ratio = h / w
+        # aspect_ratio = h / w
+
+        scale1 = min(height / h, width / w)
+        scale2 = min(height / h, width / w)
+        scale = max(scale1, scale2)
+        new_height = int(h * scale)
+        new_width = int(w * scale)
+
        lat_h = round(
-
+            new_height // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
-
+            new_width // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]
 
        clip_image_size = self.clip.model.image_size
        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
        img = resize_lanczos(img, clip_image_size, clip_image_size)
@@ -271,98 +301,101 @@ class WanI2V:
 
        # sample videos
        latent = noise
-        batch_size =
+        batch_size = 1
        freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
 
-
-
-
-
-
-
-
-        }
-
-        arg_null = {
-            'context': [context_null],
-            'clip_fea': clip_context,
-            'y': [y],
-            'freqs' : freqs,
-            'pipeline' : self,
-            'callback' : callback
-        }
-
-        arg_both= {
-            'context': [context, context_null],
-            'clip_fea': clip_context,
-            'y': [y],
-            'freqs' : freqs,
-            'pipeline' : self,
-            'callback' : callback
-        }
+        kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
+
+        if audio_proj != None:
+            kwargs.update({
+                "audio_proj": audio_proj.to(self.dtype),
+                "audio_context_lens": audio_context_lens,
+            })
 
        if self.model.enable_teacache:
+            self.model.previous_residual = [None] * (3 if audio_cfg_scale !=None else 2)
            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
 
        # self.model.to(self.device)
        if callback != None:
            callback(-1, None, True)
-
+        latent = latent.to(self.device)
        for i, t in enumerate(tqdm(timesteps)):
            offload.set_step_no_for_lora(self.model, i)
-
-
-            slg_layers_local = slg_layers
-
-            latent_model_input = [latent.to(self.device)]
+            kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
+            latent_model_input = latent
            timestep = [t]
 
            timestep = torch.stack(timestep).to(self.device)
+            kwargs.update({
+                't' :timestep,
+                'current_step' :i,
+            })
+
            if joint_pass:
-
-
+                if audio_proj == None:
+                    noise_pred_cond, noise_pred_uncond = self.model(
+                        [latent_model_input, latent_model_input],
+                        context=[context, context_null],
+                        **kwargs)
+                else:
+                    noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
+                        [latent_model_input, latent_model_input, latent_model_input],
+                        context=[context, context, context_null],
+                        audio_scale = [audio_scale, None, None ],
+                        **kwargs)
+
                if self._interrupt:
                    return None
            else:
                noise_pred_cond = self.model(
-                    latent_model_input,
-
-
-
-                    **
+                    [latent_model_input],
+                    context=[context],
+                    audio_scale = None if audio_scale == None else [audio_scale],
+                    x_id=0,
+                    **kwargs,
                )[0]
                if self._interrupt:
-                    return None
+                    return None
+
+                if audio_proj != None:
+                    noise_pred_noaudio = self.model(
+                        [latent_model_input],
+                        x_id=1,
+                        context=[context],
+                        **kwargs,
+                    )[0]
+                    if self._interrupt:
+                        return None
+
                noise_pred_uncond = self.model(
-                    latent_model_input,
-
-
-
-                    slg_layers=slg_layers_local,
-                    **arg_null,
+                    [latent_model_input],
+                    x_id=1 if audio_scale == None else 2,
+                    context=[context_null],
+                    **kwargs,
                )[0]
                if self._interrupt:
                    return None
            del latent_model_input
 
            # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
-            noise_pred_text = noise_pred_cond
            if cfg_star_switch:
-                positive_flat =
+                positive_flat = noise_pred_cond.view(batch_size, -1)
                negative_flat = noise_pred_uncond.view(batch_size, -1)
 
                alpha = optimized_scale(positive_flat,negative_flat)
                alpha = alpha.view(batch_size, 1, 1, 1)
 
-
            if (i <= cfg_zero_step):
-                noise_pred =
+                noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
            else:
                noise_pred_uncond *= alpha
-
-
-
-
+                if audio_scale == None:
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
+                else:
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
+            noise_pred_uncond, noise_pred_noaudio = None, None
            temp_x0 = sample_scheduler.step(
                noise_pred.unsqueeze(0),
                t,
@@ -376,9 +409,6 @@ class WanI2V:
            if callback is not None:
                callback(i, latent, False)
 
-
-        # x0 = [latent.to(self.device, dtype=self.dtype)]
-
        x0 = [latent]
 
        # x0 = [lat_y]
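
Not part of the commit: when audio conditioning is active, the new generate() combines three denoiser outputs (text+audio, text only, unconditional) with two guidance scales. A small sketch of that combination on random placeholder tensors, using the same names as the diff:

# Illustration only, not committed code.
import torch

noise_pred_cond = torch.randn(16, 21, 60, 104)           # text + audio conditioning
noise_pred_noaudio = torch.randn_like(noise_pred_cond)   # text only
noise_pred_uncond = torch.randn_like(noise_pred_cond)    # unconditional

guide_scale, audio_cfg_scale = 5.0, 5.0                  # example values

# Same combination as the diff: text guidance applied on the no-audio branch,
# plus an extra audio guidance term between the audio and no-audio branches.
noise_pred = (noise_pred_uncond
              + guide_scale * (noise_pred_noaudio - noise_pred_uncond)
              + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio))
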
wan/modules/attention.py
CHANGED
@@ -57,15 +57,15 @@ def sageattn_wrapper(
 ):
     q,k, v = qkv_list
     padding_length = q.shape[0] -attention_length
-    q = q[:attention_length, :, : ]
-    k = k[:attention_length, :, : ]
-    v = v[:attention_length, :, : ]
+    q = q[:attention_length, :, : ]
+    k = k[:attention_length, :, : ]
+    v = v[:attention_length, :, : ]
     if True:
         qkv_list = [q,k,v]
         del q, k ,v
-        o = alt_sageattn(qkv_list, tensor_layout="NHD")
+        o = alt_sageattn(qkv_list, tensor_layout="NHD")
     else:
-        o = sageattn(q, k, v, tensor_layout="NHD")
+        o = sageattn(q, k, v, tensor_layout="NHD")
         del q, k ,v
 
     qkv_list.clear()
@@ -107,14 +107,14 @@ def sdpa_wrapper(
     attention_length
 ):
     q,k, v = qkv_list
-    padding_length = q.shape[
-    q = q[:attention_length, :].transpose(
-    k = k[:attention_length, :].transpose(
-    v = v[:attention_length, :].transpose(
+    padding_length = q.shape[1] -attention_length
+    q = q[:attention_length, :].transpose(1,2)
+    k = k[:attention_length, :].transpose(1,2)
+    v = v[:attention_length, :].transpose(1,2)
 
     o = F.scaled_dot_product_attention(
         q, k, v, attn_mask=None, is_causal=False
-    ).
+    ).transpose(1,2)
     del q, k ,v
     qkv_list.clear()
 
@@ -159,36 +159,72 @@ def pay_attention(
     deterministic=False,
     version=None,
     force_attention= None,
-    cross_attn= False
+    cross_attn= False,
+    k_lens = None
 ):
 
     attn = offload.shared_state["_attention"] if force_attention== None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()
 
-
     # params
     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
-    assert b==1
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-    v = v.squeeze(0)
-
 
     q = q.to(v.dtype)
     k = k.to(v.dtype)
-
-
-
-
+    if b > 0 and k_lens != None and attn in ("sage2", "sdpa"):
+        # Poor man's variable-length attention: group batch rows that share the same key length
+        chunk_sizes = []
+        k_sizes = []
+        current_size = k_lens[0]
+        current_count= 1
+        for k_len in k_lens[1:]:
+            if k_len == current_size:
+                current_count += 1
+            else:
+                chunk_sizes.append(current_count)
+                k_sizes.append(current_size)
+                current_count = 1
+                current_size = k_len
+        chunk_sizes.append(current_count)
+        k_sizes.append(k_len)
+        if len(chunk_sizes) > 1 or k_lens[0] != k.shape[1]:
+            q_chunks =torch.split(q, chunk_sizes)
+            k_chunks =torch.split(k, chunk_sizes)
+            v_chunks =torch.split(v, chunk_sizes)
+            q, k, v = None, None, None
+            k_chunks = [ u[:, :sz] for u, sz in zip(k_chunks, k_sizes)]
+            v_chunks = [ u[:, :sz] for u, sz in zip(v_chunks, k_sizes)]
+            o = []
+            for sub_q, sub_k, sub_v in zip(q_chunks, k_chunks, v_chunks):
+                qkv_list = [sub_q, sub_k, sub_v]
+                sub_q, sub_k, sub_v = None, None, None
+                o.append( pay_attention(qkv_list) )
+            q_chunks, k_chunks, v_chunks = None, None, None
+            o = torch.cat(o, dim = 0)
+            return o
     if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
         warnings.warn(
             'Flash attention 3 is not available, use flash attention 2 instead.'
         )
 
     if attn=="sage" or attn=="flash":
-
-
+        if b != 1 :
+            if k_lens == None:
+                k_lens = torch.tensor( [lk] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
+            k = torch.cat([u[:v] for u, v in zip(k, k_lens)])
+            v = torch.cat([u[:v] for u, v in zip(v, k_lens)])
+            q = q.reshape(-1, *q.shape[-2:])
+            q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
+            cu_seqlens_q=torch.cat([k_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
+        else:
+            cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda")
+            cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda")
+            q = q.squeeze(0)
+            k = k.squeeze(0)
+            v = v.squeeze(0)
 
     # apply attention
     if attn=="sage":
@@ -207,7 +243,7 @@ def pay_attention(
             qkv_list = [q,k,v]
             del q,k,v
 
-            x = sageattn_wrapper(qkv_list, lq)
+            x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
         # else:
         #     layer = offload.shared_state["layer"]
         #     embed_sizes = offload.shared_state["embed_sizes"]
@@ -267,8 +303,8 @@ def pay_attention(
 
     elif attn=="sdpa":
         qkv_list = [q, k, v]
-        del q,
-        x = sdpa_wrapper( qkv_list, lq)
+        del q ,k ,v
+        x = sdpa_wrapper( qkv_list, lq) #.unsqueeze(0)
     elif attn=="flash" and version == 3:
         # Note: dropout_p, window_size are not supported in FA3 now.
         x = flash_attn_interface.flash_attn_varlen_func(
@@ -302,59 +338,11 @@ def pay_attention(
     # output
 
     elif attn=="xformers":
-
-
-
-        v
-
+        from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
+        if b != 1 and k_lens != None:
+            attn_mask = BlockDiagonalPaddedKeysMask.from_seqlens([lq] * b , lk, list(k_lens) )
+            x = memory_efficient_attention(q, k, v, attn_bias= attn_mask )
+        else:
+            x = memory_efficient_attention(q, k, v )
 
-    return x.type(out_dtype)
-
-
-def attention(
-    q,
-    k,
-    v,
-    q_lens=None,
-    k_lens=None,
-    dropout_p=0.,
-    softmax_scale=None,
-    q_scale=None,
-    causal=False,
-    window_size=(-1, -1),
-    deterministic=False,
-    dtype=torch.bfloat16,
-    fa_version=None,
-):
-    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
-        return pay_attention(
-            q=q,
-            k=k,
-            v=v,
-            q_lens=q_lens,
-            k_lens=k_lens,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            q_scale=q_scale,
-            causal=causal,
-            window_size=window_size,
-            deterministic=deterministic,
-            dtype=dtype,
-            version=fa_version,
-        )
-    else:
-        if q_lens is not None or k_lens is not None:
-            warnings.warn(
-                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
-            )
-        attn_mask = None
-
-        q = q.transpose(1, 2).to(dtype)
-        k = k.transpose(1, 2).to(dtype)
-        v = v.transpose(1, 2).to(dtype)
-
-        out = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
-
-        out = out.transpose(1, 2).contiguous()
-        return out
+    return x.type(out_dtype)
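
Not part of the commit: the k_lens chunking added to pay_attention groups batch rows that share the same key length so each group can run as one fixed-length attention call. A standalone sketch of just that grouping step, with made-up lengths:

# Illustration only, not committed code.
import torch

k_lens = torch.tensor([9, 9, 13, 13])   # e.g. four rows whose audio slices have 9, 9, 13, 13 keys

chunk_sizes, k_sizes = [], []
current_size, current_count = k_lens[0], 1
for k_len in k_lens[1:]:
    if k_len == current_size:
        current_count += 1
    else:
        chunk_sizes.append(current_count)
        k_sizes.append(current_size)
        current_count, current_size = 1, k_len
chunk_sizes.append(current_count)
k_sizes.append(k_len)

print(chunk_sizes, k_sizes)   # [2, 2] and lengths 9 and 13:
# two attention calls, one over the rows with 9 keys, one over the rows with 13 keys
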
wan/modules/model.py
CHANGED
@@ -197,9 +197,9 @@ class WanSelfAttention(nn.Module):
|
|
197 |
del q,k
|
198 |
|
199 |
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
200 |
-
qkv_list = [q,k,v]
|
201 |
-
del q,k,v
|
202 |
if block_mask == None:
|
|
|
|
|
203 |
x = pay_attention(
|
204 |
qkv_list,
|
205 |
window_size=self.window_size)
|
@@ -212,6 +212,7 @@ class WanSelfAttention(nn.Module):
|
|
212 |
.transpose(1, 2)
|
213 |
.contiguous()
|
214 |
)
|
|
|
215 |
|
216 |
# if not self._flag_ar_attention:
|
217 |
# q = rope_apply(q, grid_sizes, freqs)
|
@@ -241,7 +242,7 @@ class WanSelfAttention(nn.Module):
|
|
241 |
|
242 |
class WanT2VCrossAttention(WanSelfAttention):
|
243 |
|
244 |
-
def forward(self, xlist, context):
|
245 |
r"""
|
246 |
Args:
|
247 |
x(Tensor): Shape [B, L1, C]
|
@@ -262,6 +263,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
|
262 |
v = self.v(context).view(b, -1, n, d)
|
263 |
|
264 |
# compute attention
|
|
|
265 |
qvl_list=[q, k, v]
|
266 |
del q, k, v
|
267 |
x = pay_attention(qvl_list, cross_attn= True)
|
@@ -287,7 +289,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|
287 |
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
288 |
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
289 |
|
290 |
-
def forward(self, xlist, context):
|
291 |
r"""
|
292 |
Args:
|
293 |
x(Tensor): Shape [B, L1, C]
|
@@ -310,6 +312,8 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|
310 |
del x
|
311 |
self.norm_q(q)
|
312 |
q= q.view(b, -1, n, d)
|
|
|
|
|
313 |
k = self.k(context)
|
314 |
self.norm_k(k)
|
315 |
k = k.view(b, -1, n, d)
|
@@ -334,6 +338,8 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|
334 |
img_x = img_x.flatten(2)
|
335 |
x += img_x
|
336 |
del img_x
|
|
|
|
|
337 |
x = self.o(x)
|
338 |
return x
|
339 |
|
@@ -398,7 +404,10 @@ class WanAttentionBlock(nn.Module):
|
|
398 |
hints= None,
|
399 |
context_scale=1.0,
|
400 |
cam_emb= None,
|
401 |
-
block_mask = None
|
|
|
|
|
|
|
402 |
):
|
403 |
r"""
|
404 |
Args:
|
@@ -433,7 +442,7 @@ class WanAttentionBlock(nn.Module):
|
|
433 |
if cam_emb != None:
|
434 |
cam_emb = self.cam_encoder(cam_emb)
|
435 |
cam_emb = cam_emb.repeat(1, 2, 1)
|
436 |
-
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[
|
437 |
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
|
438 |
x_mod += cam_emb
|
439 |
|
@@ -453,7 +462,7 @@ class WanAttentionBlock(nn.Module):
|
|
453 |
y = y.to(attention_dtype)
|
454 |
ylist= [y]
|
455 |
del y
|
456 |
-
x += self.cross_attn(ylist, context).to(dtype)
|
457 |
|
458 |
y = self.norm2(x)
|
459 |
|
@@ -610,6 +619,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
610 |
eps=1e-6,
|
611 |
recammaster = False,
|
612 |
inject_sample_info = False,
|
|
|
613 |
):
|
614 |
r"""
|
615 |
Initialize the diffusion model backbone.
|
@@ -742,43 +752,48 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
742 |
block.projector.weight = nn.Parameter(torch.eye(dim))
|
743 |
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
745 |
|
746 |
-
def lock_layers_dtypes(self, dtype = torch.float32, force = False):
|
747 |
-
count = 0
|
748 |
-
layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
|
749 |
-
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
|
750 |
if hasattr(self, "fps_embedding"):
|
751 |
-
|
752 |
|
753 |
if hasattr(self, "vace_patch_embedding"):
|
754 |
-
|
755 |
-
|
756 |
for block in self.vace_blocks:
|
757 |
-
|
|
|
|
|
758 |
|
759 |
# cam master
|
760 |
if hasattr(self.blocks[0], "projector"):
|
761 |
for block in self.blocks:
|
762 |
-
|
763 |
|
764 |
-
for
|
765 |
-
|
766 |
-
|
767 |
-
if hasattr(layer, "weight"):
|
768 |
-
if layer.weight.dtype == dtype :
|
769 |
-
count += 1
|
770 |
-
elif force:
|
771 |
-
if hasattr(layer, "weight"):
|
772 |
-
layer.weight.data = layer.weight.data.to(dtype)
|
773 |
-
if hasattr(layer, "bias"):
|
774 |
-
layer.bias.data = layer.bias.data.to(dtype)
|
775 |
-
count += 1
|
776 |
-
|
777 |
-
layer._lock_dtype = dtype
|
778 |
|
|
|
|
|
|
|
|
|
779 |
|
780 |
-
|
781 |
-
self._lock_dtype = dtype
|
782 |
|
783 |
|
784 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
@@ -788,7 +803,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
788 |
t = torch.stack([t])
|
789 |
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
|
790 |
e_list.append(time_emb)
|
791 |
-
|
792 |
best_threshold = 0.01
|
793 |
best_diff = 1000
|
794 |
best_signed_diff = 1000
|
@@ -798,12 +813,16 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
798 |
accumulated_rel_l1_distance =0
|
799 |
nb_steps = 0
|
800 |
diff = 1000
|
|
|
801 |
for i, t in enumerate(timesteps):
|
802 |
skip = False
|
803 |
-
if not (i<=start_step or i== len(timesteps)):
|
804 |
-
|
|
|
|
|
805 |
if accumulated_rel_l1_distance < threshold:
|
806 |
skip = True
|
|
|
807 |
else:
|
808 |
accumulated_rel_l1_distance = 0
|
809 |
if not skip:
|
@@ -812,6 +831,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
812 |
diff = abs(signed_diff)
|
813 |
if diff < best_diff:
|
814 |
best_threshold = threshold
|
|
|
815 |
best_diff = diff
|
816 |
best_signed_diff = signed_diff
|
817 |
elif diff > best_diff:
|
@@ -819,6 +839,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
819 |
threshold += 0.01
|
820 |
self.rel_l1_thresh = best_threshold
|
821 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
|
|
822 |
return best_threshold
|
823 |
|
824 |
|
@@ -834,7 +855,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
834 |
freqs = None,
|
835 |
pipeline = None,
|
836 |
current_step = 0,
|
837 |
-
|
838 |
max_steps = 0,
|
839 |
slg_layers=None,
|
840 |
callback = None,
|
@@ -842,10 +863,13 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
842 |
fps = None,
|
843 |
causal_block_size = 1,
|
844 |
causal_attention = False,
|
845 |
-
|
|
|
|
|
|
|
846 |
):
|
847 |
-
#
|
848 |
-
|
849 |
|
850 |
if self.model_type == 'i2v':
|
851 |
assert clip_fea is not None and y is not None
|
@@ -854,20 +878,32 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
854 |
if torch.is_tensor(freqs) and freqs.device != device:
|
855 |
freqs = freqs.to(device)
|
856 |
|
857 |
-
if y is not None:
|
858 |
-
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
859 |
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
-
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
|
|
871 |
block_num = frame_num // causal_block_size
|
872 |
range_tensor = torch.arange(block_num).view(-1, 1)
|
873 |
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
|
@@ -878,30 +914,21 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
878 |
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
879 |
del causal_mask
|
880 |
|
881 |
-
offload.shared_state["embed_sizes"] =
|
882 |
offload.shared_state["step_no"] = current_step
|
883 |
offload.shared_state["max_steps"] = max_steps
|
884 |
|
885 |
-
|
886 |
-
x = x[0]
|
887 |
-
if x_neg !=None:
|
888 |
-
x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
|
889 |
-
x_neg = x_neg[0]
|
890 |
|
891 |
-
if t.dim() == 2:
|
892 |
-
b, f = t.shape
|
893 |
-
_flag_df = True
|
894 |
-
else:
|
895 |
-
_flag_df = False
|
896 |
e = self.time_embedding(
|
897 |
-
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
898 |
) # b, dim
|
899 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
900 |
|
901 |
if self.inject_sample_info:
|
902 |
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
903 |
|
904 |
-
fps_emb = self.fps_embedding(fps).to(dtype)
|
905 |
if _flag_df:
|
906 |
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
907 |
else:
|
@@ -913,30 +940,28 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
913 |
if clip_fea is not None:
|
914 |
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
915 |
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
916 |
-
|
917 |
-
joint_pass = len(context) > 0
|
918 |
-
x_list = [x]
|
919 |
-
if joint_pass:
|
920 |
-
if x_neg == None:
|
921 |
-
x_list += [x.clone() for i in range(len(context) - 1) ]
|
922 |
-
else:
|
923 |
-
x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
|
924 |
-
is_uncond = False
|
925 |
-
del x
|
926 |
context_list = context
|
|
|
927 |
|
928 |
# arguments
|
929 |
|
930 |
kwargs = dict(
|
931 |
grid_sizes=grid_sizes,
|
932 |
freqs=freqs,
|
933 |
-
cam_emb = cam_emb
|
|
|
934 |
)
|
935 |
|
936 |
if vace_context == None:
|
937 |
hints_list = [None ] *len(x_list)
|
938 |
else:
|
939 |
-
# embeddings
|
940 |
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
941 |
c = [u.flatten(2).transpose(1, 2) for u in c]
|
942 |
c = c[0]
|
@@ -947,7 +972,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
947 |
|
948 |
should_calc = True
|
949 |
if self.enable_teacache:
|
950 |
-
if
|
951 |
should_calc = self.should_calc
|
952 |
else:
|
953 |
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
|
@@ -955,11 +980,12 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
955 |
self.accumulated_rel_l1_distance = 0
|
956 |
else:
|
957 |
rescale_func = np.poly1d(self.coefficients)
|
958 |
-
|
|
|
959 |
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
960 |
should_calc = False
|
961 |
self.teacache_skipped_steps += 1
|
962 |
-
# print(f"Teacache Skipped Step
|
963 |
else:
|
964 |
should_calc = True
|
965 |
self.accumulated_rel_l1_distance = 0
|
@@ -967,15 +993,23 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
967 |
self.should_calc = should_calc
|
968 |
|
969 |
if not should_calc:
|
970 |
-
|
971 |
-
|
|
|
|
972 |
else:
|
973 |
if self.enable_teacache:
|
974 |
-
if joint_pass
|
975 |
-
self.
|
976 |
-
|
977 |
-
self.
|
978 |
-
ori_hidden_states =
|
|
|
|
|
|
|
979 |
|
980 |
for block_idx, block in enumerate(self.blocks):
|
981 |
offload.shared_state["layer"] = block_idx
|
@@ -984,29 +1018,30 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
984 |
if pipeline._interrupt:
|
985 |
return [None] * len(x_list)
|
986 |
|
987 |
-
if slg_layers is not None and block_idx in slg_layers:
|
988 |
-
if
|
989 |
continue
|
990 |
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
991 |
-
|
992 |
else:
|
993 |
-
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
|
994 |
-
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
|
995 |
del x
|
996 |
del context, hints
|
997 |
|
998 |
if self.enable_teacache:
|
999 |
if joint_pass:
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
|
|
|
1003 |
else:
|
1004 |
-
residual = ori_hidden_states # just to have a readable code
|
1005 |
-
torch.sub(x_list[0], ori_hidden_states, out=residual)
|
1006 |
-
|
1007 |
-
self.previous_residual_uncond = residual
|
1008 |
-
else:
|
1009 |
-
self.previous_residual_cond = residual
|
1010 |
residual, ori_hidden_states = None, None
|
1011 |
|
1012 |
for i, x in enumerate(x_list):
|
@@ -1037,10 +1072,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
1037 |
|
1038 |
c = self.out_dim
|
1039 |
out = []
|
1040 |
-
for u
|
1041 |
-
u = u[:math.prod(
|
1042 |
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
1043 |
-
u = u.reshape(c, *[i * j for i, j in zip(
|
1044 |
out.append(u)
|
1045 |
return out
|
1046 |
|
|
|
197 |
del q,k
|
198 |
|
199 |
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
|
|
|
|
200 |
if block_mask == None:
|
201 |
+
qkv_list = [q,k,v]
|
202 |
+
del q,k,v
|
203 |
x = pay_attention(
|
204 |
qkv_list,
|
205 |
window_size=self.window_size)
|
|
|
212 |
.transpose(1, 2)
|
213 |
.contiguous()
|
214 |
)
|
215 |
+
del q,k,v
|
216 |
|
217 |
# if not self._flag_ar_attention:
|
218 |
# q = rope_apply(q, grid_sizes, freqs)
|
|
|
242 |
|
243 |
class WanT2VCrossAttention(WanSelfAttention):
|
244 |
|
245 |
+
def forward(self, xlist, context, grid_sizes, *args, **kwargs):
|
246 |
r"""
|
247 |
Args:
|
248 |
x(Tensor): Shape [B, L1, C]
|
|
|
263 |
v = self.v(context).view(b, -1, n, d)
|
264 |
|
265 |
# compute attention
|
266 |
+
v = v.contiguous().clone()
|
267 |
qvl_list=[q, k, v]
|
268 |
del q, k, v
|
269 |
x = pay_attention(qvl_list, cross_attn= True)
|
|
|
289 |
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
290 |
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
291 |
|
292 |
+
def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
|
293 |
r"""
|
294 |
Args:
|
295 |
x(Tensor): Shape [B, L1, C]
|
|
|
312 |
del x
|
313 |
self.norm_q(q)
|
314 |
q= q.view(b, -1, n, d)
|
315 |
+
if audio_scale != None:
|
316 |
+
audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
|
317 |
k = self.k(context)
|
318 |
self.norm_k(k)
|
319 |
k = k.view(b, -1, n, d)
|
|
|
338 |
img_x = img_x.flatten(2)
|
339 |
x += img_x
|
340 |
del img_x
|
341 |
+
if audio_scale != None:
|
342 |
+
x.add_(audio_x, alpha= audio_scale)
|
343 |
x = self.o(x)
|
344 |
return x
|
345 |
|
|
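The hunk above is the FantasySpeaking hook in the image-to-video cross-attention: when `audio_scale` is set, a per-block `processor` turns the video query tokens and the projected audio features into an extra attention output, which is then added on top of the text/image result with `x.add_(audio_x, alpha=audio_scale)`. The actual `WanCrossAttentionProcessor` lives in `fantasytalking/model.py` and is not shown in this diff; the snippet below is only a minimal sketch of that pattern, with hypothetical layer names and shapes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioCrossAttentionSketch(nn.Module):
    # Hypothetical processor: video query tokens attend to projected audio tokens.
    # Layer names and shapes are illustrative; the real WanCrossAttentionProcessor
    # in fantasytalking/model.py may be organised differently.
    def __init__(self, audio_dim, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.k_audio = nn.Linear(audio_dim, dim)
        self.v_audio = nn.Linear(audio_dim, dim)

    def forward(self, q, audio_proj):
        # q: (B, L, heads, head_dim) video queries, audio_proj: (B, T, audio_dim)
        b, l, h, d = q.shape
        k = self.k_audio(audio_proj).view(b, -1, h, d)
        v = self.v_audio(audio_proj).view(b, -1, h, d)
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2).reshape(b, l, h * d)

proc = AudioCrossAttentionSketch(audio_dim=768, dim=128, num_heads=8)
q = torch.randn(1, 1560, 8, 16)      # (B, L, heads, head_dim)
audio = torch.randn(1, 32, 768)      # projected audio features
audio_x = proc(q, audio)             # (1, 1560, 128); the block then does x.add_(audio_x, alpha=audio_scale)

Scaling the contribution with `alpha=audio_scale` leaves the base cross-attention untouched whenever the audio branch is disabled (`audio_scale=None`).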
|
404 |
hints= None,
|
405 |
context_scale=1.0,
|
406 |
cam_emb= None,
|
407 |
+
block_mask = None,
|
408 |
+
audio_proj= None,
|
409 |
+
audio_context_lens= None,
|
410 |
+
audio_scale=None,
|
411 |
):
|
412 |
r"""
|
413 |
Args:
|
|
|
442 |
if cam_emb != None:
|
443 |
cam_emb = self.cam_encoder(cam_emb)
|
444 |
cam_emb = cam_emb.repeat(1, 2, 1)
|
445 |
+
cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[1], grid_sizes[2], 1)
|
446 |
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
|
447 |
x_mod += cam_emb
|
448 |
|
|
|
462 |
y = y.to(attention_dtype)
|
463 |
ylist= [y]
|
464 |
del y
|
465 |
+
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
466 |
|
467 |
y = self.norm2(x)
|
468 |
|
|
|
619 |
eps=1e-6,
|
620 |
recammaster = False,
|
621 |
inject_sample_info = False,
|
622 |
+
fantasytalking_dim = 0,
|
623 |
):
|
624 |
r"""
|
625 |
Initialize the diffusion model backbone.
|
|
|
752 |
block.projector.weight = nn.Parameter(torch.eye(dim))
|
753 |
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
754 |
|
755 |
+
if fantasytalking_dim > 0:
|
756 |
+
from fantasytalking.model import WanCrossAttentionProcessor
|
757 |
+
for block in self.blocks:
|
758 |
+
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
|
759 |
+
|
760 |
+
|
761 |
+
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
|
762 |
+
layer_list = [self.head, self.head.head, self.patch_embedding]
|
763 |
+
target_dype= dtype
|
764 |
+
|
765 |
+
layer_list2 = [ self.time_embedding, self.time_embedding[0], self.time_embedding[2],
|
766 |
+
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
|
767 |
+
|
768 |
+
for block in self.blocks:
|
769 |
+
layer_list2 += [block.norm3]
|
770 |
|
|
|
|
|
|
|
|
|
771 |
if hasattr(self, "fps_embedding"):
|
772 |
+
layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
|
773 |
|
774 |
if hasattr(self, "vace_patch_embedding"):
|
775 |
+
layer_list2 += [self.vace_patch_embedding]
|
776 |
+
layer_list2 += [self.vace_blocks[0].before_proj]
|
777 |
for block in self.vace_blocks:
|
778 |
+
layer_list2 += [block.after_proj, block.norm3]
|
779 |
+
|
780 |
+
target_dype2 = hybrid_dtype if hybrid_dtype != None else dtype
|
781 |
|
782 |
# cam master
|
783 |
if hasattr(self.blocks[0], "projector"):
|
784 |
for block in self.blocks:
|
785 |
+
layer_list2 += [block.projector]
|
786 |
|
787 |
+
for current_layer_list, current_dtype in zip([layer_list, layer_list2], [target_dype, target_dype2]):
|
788 |
+
for layer in current_layer_list:
|
789 |
+
layer._lock_dtype = dtype
|
|
|
|
790 |
|
791 |
+
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
|
792 |
+
layer.weight.data = layer.weight.data.to(current_dtype)
|
793 |
+
if hasattr(layer, "bias"):
|
794 |
+
layer.bias.data = layer.bias.data.to(current_dtype)
|
795 |
|
796 |
+
self._lock_dtype = dtype
|
|
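The `lock_layers_dtypes` rewrite above backs the new mixed 16/32-bit option: the output head and patch embedding are pinned to `dtype` (float32 when mixed precision is enabled), a second group (time embedding/projection, norms, Vace projections, the RecamMaster projector) may stay in `hybrid_dtype`, and every pinned layer is tagged with `_lock_dtype`, presumably so the later `offload.change_dtype(...)` pass leaves it alone. A minimal standalone sketch of that pattern, with made-up module names:

import torch
import torch.nn as nn

def lock_dtypes(groups, dtypes):
    # Pin each group of layers to its dtype and tag it with _lock_dtype so that a
    # later blanket dtype conversion can skip the tagged layers. Sketch only: the
    # real method collects the WanModel submodules into its two lists itself.
    for layers, target in zip(groups, dtypes):
        for layer in layers:
            layer._lock_dtype = target
            if hasattr(layer, "weight") and layer.weight.dtype != target:
                layer.weight.data = layer.weight.data.to(target)
                if getattr(layer, "bias", None) is not None:
                    layer.bias.data = layer.bias.data.to(target)

head = nn.Linear(128, 64).to(torch.bfloat16)        # sensitive layer, promote to float32
time_proj = nn.Linear(128, 128).to(torch.bfloat16)  # may stay in the hybrid dtype
lock_dtypes([[head], [time_proj]], [torch.float32, torch.bfloat16])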
|
797 |
|
798 |
|
799 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
|
|
803 |
t = torch.stack([t])
|
804 |
time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
|
805 |
e_list.append(time_emb)
|
806 |
+
best_deltas = None
|
807 |
best_threshold = 0.01
|
808 |
best_diff = 1000
|
809 |
best_signed_diff = 1000
|
|
|
813 |
accumulated_rel_l1_distance =0
|
814 |
nb_steps = 0
|
815 |
diff = 1000
|
816 |
+
deltas = []
|
817 |
for i, t in enumerate(timesteps):
|
818 |
skip = False
|
819 |
+
if not (i<=start_step or i== len(timesteps)-1):
|
820 |
+
delta = abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
|
821 |
+
# deltas.append(delta)
|
822 |
+
accumulated_rel_l1_distance += delta
|
823 |
if accumulated_rel_l1_distance < threshold:
|
824 |
skip = True
|
825 |
+
# deltas.append("SKIP")
|
826 |
else:
|
827 |
accumulated_rel_l1_distance = 0
|
828 |
if not skip:
|
|
|
831 |
diff = abs(signed_diff)
|
832 |
if diff < best_diff:
|
833 |
best_threshold = threshold
|
834 |
+
best_deltas = deltas
|
835 |
best_diff = diff
|
836 |
best_signed_diff = signed_diff
|
837 |
elif diff > best_diff:
|
|
|
839 |
threshold += 0.01
|
840 |
self.rel_l1_thresh = best_threshold
|
841 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
842 |
+
# print(f"deltas:{best_deltas}")
|
843 |
return best_threshold
|
844 |
|
845 |
|
|
|
855 |
freqs = None,
|
856 |
pipeline = None,
|
857 |
current_step = 0,
|
858 |
+
x_id= 0,
|
859 |
max_steps = 0,
|
860 |
slg_layers=None,
|
861 |
callback = None,
|
|
|
863 |
fps = None,
|
864 |
causal_block_size = 1,
|
865 |
causal_attention = False,
|
866 |
+
audio_proj=None,
|
867 |
+
audio_context_lens=None,
|
868 |
+
audio_scale=None,
|
869 |
+
|
870 |
):
|
871 |
+
# patch_dtype = self.patch_embedding.weight.dtype
|
872 |
+
modulation_dtype = self.time_projection[1].weight.dtype
|
873 |
|
874 |
if self.model_type == 'i2v':
|
875 |
assert clip_fea is not None and y is not None
|
|
|
878 |
if torch.is_tensor(freqs) and freqs.device != device:
|
879 |
freqs = freqs.to(device)
|
880 |
|
|
|
|
|
881 |
|
882 |
+
x_list = x
|
883 |
+
joint_pass = len(x_list) > 1
|
884 |
+
is_source_x = [ x.data_ptr() == x_list[0].data_ptr() and i > 0 for i, x in enumerate(x_list) ]
|
885 |
+
last_x_idx = 0
|
886 |
+
for i, (is_source, x) in enumerate(zip(is_source_x, x_list)):
|
887 |
+
if is_source:
|
888 |
+
x_list[i] = x_list[0].clone()
|
889 |
+
last_x_idx = i
|
890 |
+
else:
|
891 |
+
# image source
|
892 |
+
if y is not None:
|
893 |
+
x = torch.cat([x, y], dim=0)
|
894 |
+
# embeddings
|
895 |
+
x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
|
896 |
+
grid_sizes = x.shape[2:]
|
897 |
+
x = x.flatten(2).transpose(1, 2)
|
898 |
+
x_list[i] = x
|
899 |
+
x, y = None, None
|
900 |
+
|
901 |
+
|
902 |
+
block_mask = None
|
903 |
+
if causal_attention and causal_block_size > 0 and False: # NEVER WORKED
|
904 |
+
frame_num = grid_sizes[0]
|
905 |
+
height = grid_sizes[1]
|
906 |
+
width = grid_sizes[2]
|
907 |
block_num = frame_num // causal_block_size
|
908 |
range_tensor = torch.arange(block_num).view(-1, 1)
|
909 |
range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
|
|
|
914 |
block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
915 |
del causal_mask
|
916 |
|
917 |
+
offload.shared_state["embed_sizes"] = grid_sizes
|
918 |
offload.shared_state["step_no"] = current_step
|
919 |
offload.shared_state["max_steps"] = max_steps
|
920 |
|
921 |
+
_flag_df = t.dim() == 2
|
|
|
|
922 |
|
|
|
|
923 |
e = self.time_embedding(
|
924 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(modulation_dtype) # self.patch_embedding.weight.dtype)
|
925 |
) # b, dim
|
926 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
927 |
|
928 |
if self.inject_sample_info:
|
929 |
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
930 |
|
931 |
+
fps_emb = self.fps_embedding(fps).to(e.dtype)
|
932 |
if _flag_df:
|
933 |
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
934 |
else:
|
|
|
940 |
if clip_fea is not None:
|
941 |
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
942 |
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
943 |
+
|
|
|
|
944 |
context_list = context
|
945 |
+
if audio_scale != None:
|
946 |
+
audio_scale_list = audio_scale
|
947 |
+
else:
|
948 |
+
audio_scale_list = [None] * len(x_list)
|
949 |
|
950 |
# arguments
|
951 |
|
952 |
kwargs = dict(
|
953 |
grid_sizes=grid_sizes,
|
954 |
freqs=freqs,
|
955 |
+
cam_emb = cam_emb,
|
956 |
+
block_mask = block_mask,
|
957 |
+
audio_proj=audio_proj,
|
958 |
+
audio_context_lens=audio_context_lens,
|
959 |
)
|
960 |
|
961 |
if vace_context == None:
|
962 |
hints_list = [None ] *len(x_list)
|
963 |
else:
|
964 |
+
# Vace embeddings
|
965 |
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
966 |
c = [u.flatten(2).transpose(1, 2) for u in c]
|
967 |
c = c[0]
|
|
|
972 |
|
973 |
should_calc = True
|
974 |
if self.enable_teacache:
|
975 |
+
if x_id != 0:
|
976 |
should_calc = self.should_calc
|
977 |
else:
|
978 |
if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
|
|
|
980 |
self.accumulated_rel_l1_distance = 0
|
981 |
else:
|
982 |
rescale_func = np.poly1d(self.coefficients)
|
983 |
+
delta = abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
|
984 |
+
self.accumulated_rel_l1_distance += delta
|
985 |
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
986 |
should_calc = False
|
987 |
self.teacache_skipped_steps += 1
|
988 |
+
# print(f"Teacache Skipped Step no {current_step} ({self.teacache_skipped_steps}/{current_step}), delta={delta}" )
|
989 |
else:
|
990 |
should_calc = True
|
991 |
self.accumulated_rel_l1_distance = 0
|
|
|
993 |
self.should_calc = should_calc
|
994 |
|
995 |
if not should_calc:
|
996 |
+
if joint_pass:
|
997 |
+
for i, x in enumerate(x_list):
|
998 |
+
x += self.previous_residual[i]
|
999 |
+
else:
|
1000 |
+
x = x_list[0]
|
1001 |
+
x += self.previous_residual[x_id]
|
1002 |
+
x = None
|
1003 |
else:
|
1004 |
if self.enable_teacache:
|
1005 |
+
if joint_pass:
|
1006 |
+
self.previous_residual = [ None ] * len(self.previous_residual)
|
1007 |
+
else:
|
1008 |
+
self.previous_residual[x_id] = None
|
1009 |
+
ori_hidden_states = [ None ] * len(x_list)
|
1010 |
+
ori_hidden_states[0] = x_list[0].clone()
|
1011 |
+
for i in range(1, len(x_list)):
|
1012 |
+
ori_hidden_states[i] = ori_hidden_states[0] if is_source_x[i] else x_list[i].clone()
|
1013 |
|
1014 |
for block_idx, block in enumerate(self.blocks):
|
1015 |
offload.shared_state["layer"] = block_idx
|
|
|
1018 |
if pipeline._interrupt:
|
1019 |
return [None] * len(x_list)
|
1020 |
|
1021 |
+
if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
|
1022 |
+
if not joint_pass:
|
1023 |
continue
|
1024 |
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
|
|
1025 |
else:
|
1026 |
+
for i, (x, context, hints, audio_scale) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list)):
|
1027 |
+
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
|
1028 |
del x
|
1029 |
del context, hints
|
1030 |
|
1031 |
if self.enable_teacache:
|
1032 |
if joint_pass:
|
1033 |
+
for i, (x, ori, is_source) in enumerate(zip(x_list, ori_hidden_states, is_source_x)) :
|
1034 |
+
if i == 0 or is_source and i != last_x_idx :
|
1035 |
+
self.previous_residual[i] = torch.sub(x, ori)
|
1036 |
+
else:
|
1037 |
+
self.previous_residual[i] = ori
|
1038 |
+
torch.sub(x, ori, out=self.previous_residual[i])
|
1039 |
+
ori_hidden_states[i] = None
|
1040 |
+
x , ori = None, None
|
1041 |
else:
|
1042 |
+
residual = ori_hidden_states[0] # just to have a readable code
|
1043 |
+
torch.sub(x_list[0], ori_hidden_states[0], out=residual)
|
1044 |
+
self.previous_residual[x_id] = residual
|
|
|
|
|
|
|
1045 |
residual, ori_hidden_states = None, None
|
1046 |
|
1047 |
for i, x in enumerate(x_list):
|
|
|
1072 |
|
1073 |
c = self.out_dim
|
1074 |
out = []
|
1075 |
+
for u in x:
|
1076 |
+
u = u[:math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
|
1077 |
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
1078 |
+
u = u.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
1079 |
out.append(u)
|
1080 |
return out
|
1081 |
|
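Much of the model.py change above generalizes the TeaCache bookkeeping from two hard-coded residual slots to a `previous_residual` list indexed by `x_id`, so the same cache serves the joint pass over any number of streams (for example the three Phantom passes). The skip criterion itself is unchanged; the sketch below restates it in isolation, with illustrative names (`rescale_coeffs`, `threshold`) rather than the module's attributes.

import numpy as np
import torch

def teacache_should_skip(e_prev, e_curr, acc, rescale_coeffs, threshold):
    # One denoising step of the TeaCache criterion used above: the relative L1
    # change of the time-modulation embedding is rescaled by a fitted polynomial
    # and accumulated; while the running total stays under the threshold the
    # transformer blocks are skipped and the cached residual is re-applied.
    rescale = np.poly1d(rescale_coeffs)
    rel_l1 = ((e_curr - e_prev).abs().mean() / e_prev.abs().mean()).item()
    acc += abs(rescale(rel_l1))
    if acc < threshold:
        return True, acc     # reuse previous_residual[x_id]
    return False, 0.0        # run the blocks, reset the accumulator

# illustrative call, using the 1.3B text2video coefficients from wgp.py
coeffs = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
e0, e1 = torch.randn(1, 1536), torch.randn(1, 1536)
skip, acc = teacache_should_skip(e0, e1, 0.0, coeffs, threshold=0.15)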
wan/modules/sage2_core.py
CHANGED
@@ -140,7 +140,7 @@ def sageattn(
|
|
140 |
elif arch == "sm90":
|
141 |
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
|
142 |
elif arch == "sm120":
|
143 |
-
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
|
144 |
else:
|
145 |
raise ValueError(f"Unsupported CUDA architecture: {arch}")
|
146 |
|
|
|
140 |
elif arch == "sm90":
|
141 |
return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
|
142 |
elif arch == "sm120":
|
143 |
+
return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v= True) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
|
144 |
else:
|
145 |
raise ValueError(f"Unsupported CUDA architecture: {arch}")
|
146 |
|
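The sage2_core.py change only touches the sm120 branch, enabling `smooth_v=True` for the fp8 kernel with an fp32 accumulator, since the Triton path is not usable on that architecture yet. Kernel selection is keyed on a tag derived from the GPU's compute capability; the helper below is a sketch of how such a tag can be obtained, not the function the file actually uses.

import torch

def cuda_arch(device=0):
    # Return an "smXY" tag for the current GPU, e.g. "sm90" or "sm120"; a sketch of
    # the kind of tag sage2_core.py keys its kernel choice on (helper name assumed).
    if not torch.cuda.is_available():
        return "cpu"
    major, minor = torch.cuda.get_device_capability(device)
    return f"sm{major}{minor}"

# a consumer Blackwell GPU reports capability (12, 0), i.e. "sm120", which now
# selects the fp8 kernel with pv_accum_dtype="fp32" and smooth_v=True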
wan/text2video.py
CHANGED
@@ -78,15 +78,16 @@ class WanT2V:
|
|
78 |
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
79 |
device=self.device)
|
80 |
|
81 |
-
logging.info(f"Creating WanModel from {model_filename}")
|
82 |
from mmgp import offload
|
83 |
# model_filename
|
|
|
84 |
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
|
85 |
# offload.load_model_data(self.model, "e:/vace.safetensors")
|
86 |
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
87 |
# self.model.to(torch.bfloat16)
|
88 |
# self.model.cpu()
|
89 |
-
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype
|
90 |
offload.change_dtype(self.model, dtype, True)
|
91 |
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
|
92 |
# offload.save_model(self.model, "phantom_1.3B.safetensors")
|
@@ -95,7 +96,7 @@ class WanT2V:
|
|
95 |
|
96 |
self.sample_neg_prompt = config.sample_neg_prompt
|
97 |
|
98 |
-
if "Vace" in model_filename:
|
99 |
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
100 |
min_area=480*832,
|
101 |
max_area=480*832,
|
@@ -107,7 +108,7 @@ class WanT2V:
|
|
107 |
|
108 |
self.adapt_vace_model()
|
109 |
|
110 |
-
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
|
111 |
if ref_images is None:
|
112 |
ref_images = [None] * len(frames)
|
113 |
else:
|
@@ -119,6 +120,11 @@ class WanT2V:
|
|
119 |
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
120 |
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
121 |
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
|
|
|
|
|
|
|
|
|
|
122 |
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
123 |
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
124 |
|
@@ -288,7 +294,10 @@ class WanT2V:
|
|
288 |
slg_end = 1.0,
|
289 |
cfg_star_switch = True,
|
290 |
cfg_zero_step = 5,
|
291 |
-
|
|
|
|
|
|
|
292 |
r"""
|
293 |
Generates video frames from text prompt using diffusion process.
|
294 |
|
@@ -343,20 +352,20 @@ class WanT2V:
|
|
343 |
size = (source_video.shape[2], source_video.shape[1])
|
344 |
source_video = source_video.to(dtype=self.dtype , device=self.device)
|
345 |
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
346 |
-
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
|
347 |
del source_video
|
348 |
# Process target camera (recammaster)
|
349 |
from wan.utils.cammmaster_tools import get_camera_embedding
|
350 |
cam_emb = get_camera_embedding(target_camera)
|
351 |
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
352 |
|
353 |
-
if
|
354 |
# vace context encode
|
355 |
input_frames = [u.to(self.device) for u in input_frames]
|
356 |
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
357 |
input_masks = [u.to(self.device) for u in input_masks]
|
358 |
|
359 |
-
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
|
360 |
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
361 |
z = self.vace_latent(z0, m0)
|
362 |
|
@@ -365,10 +374,10 @@ class WanT2V:
|
|
365 |
else:
|
366 |
if input_ref_images != None: # Phantom Ref images
|
367 |
phantom = True
|
368 |
-
input_ref_images =
|
369 |
-
input_ref_images_neg =
|
370 |
F = frame_num
|
371 |
-
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images
|
372 |
size[1] // self.vae_stride[1],
|
373 |
size[0] // self.vae_stride[2])
|
374 |
|
@@ -405,37 +414,48 @@ class WanT2V:
|
|
405 |
raise NotImplementedError("Unsupported solver.")
|
406 |
|
407 |
# sample videos
|
408 |
-
latents = noise
|
409 |
del noise
|
410 |
-
batch_size =
|
411 |
if target_camera != None:
|
412 |
-
shape = list(latents
|
413 |
shape[0] *= 2
|
414 |
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
415 |
else:
|
416 |
-
freqs = get_rotary_pos_embed(latents
|
417 |
|
418 |
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
|
419 |
|
420 |
if target_camera != None:
|
421 |
kwargs.update({'cam_emb': cam_emb})
|
422 |
|
423 |
-
if
|
|
|
424 |
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
|
|
|
|
|
425 |
|
426 |
|
427 |
if self.model.enable_teacache:
|
|
|
|
|
428 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
429 |
if callback != None:
|
430 |
callback(-1, None, True)
|
431 |
for i, t in enumerate(tqdm(timesteps)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
if target_camera != None:
|
433 |
-
latent_model_input =
|
434 |
else:
|
435 |
latent_model_input = latents
|
436 |
-
|
437 |
-
|
438 |
-
slg_layers_local = slg_layers
|
439 |
timestep = [t]
|
440 |
offload.set_step_no_for_lora(self.model, i)
|
441 |
timestep = torch.stack(timestep)
|
@@ -444,38 +464,38 @@ class WanT2V:
|
|
444 |
if joint_pass:
|
445 |
if phantom:
|
446 |
pos_it, pos_i, neg = self.model(
|
447 |
-
|
448 |
-
|
449 |
context = [context, context_null, context_null], **kwargs)
|
450 |
else:
|
451 |
noise_pred_cond, noise_pred_uncond = self.model(
|
452 |
-
latent_model_input,
|
453 |
if self._interrupt:
|
454 |
return None
|
455 |
else:
|
456 |
if phantom:
|
457 |
pos_it = self.model(
|
458 |
-
[torch.cat([
|
459 |
)[0]
|
460 |
if self._interrupt:
|
461 |
return None
|
462 |
pos_i = self.model(
|
463 |
-
[torch.cat([
|
464 |
)[0]
|
465 |
if self._interrupt:
|
466 |
return None
|
467 |
neg = self.model(
|
468 |
-
|
469 |
)[0]
|
470 |
if self._interrupt:
|
471 |
return None
|
472 |
else:
|
473 |
noise_pred_cond = self.model(
|
474 |
-
latent_model_input,
|
475 |
if self._interrupt:
|
476 |
return None
|
477 |
noise_pred_uncond = self.model(
|
478 |
-
latent_model_input,
|
479 |
if self._interrupt:
|
480 |
return None
|
481 |
|
@@ -505,21 +525,21 @@ class WanT2V:
|
|
505 |
temp_x0 = sample_scheduler.step(
|
506 |
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
507 |
t,
|
508 |
-
latents
|
509 |
return_dict=False,
|
510 |
generator=seed_g)[0]
|
511 |
-
latents =
|
512 |
del temp_x0
|
513 |
|
514 |
if callback is not None:
|
515 |
-
callback(i, latents
|
516 |
|
517 |
-
x0 = latents
|
518 |
|
519 |
if input_frames == None:
|
520 |
if phantom:
|
521 |
# phantom post processing
|
522 |
-
x0 = [x0_[:,:-input_ref_images
|
523 |
videos = self.vae.decode(x0, VAE_tile_size)
|
524 |
else:
|
525 |
# vace post processing
|
|
|
78 |
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
79 |
device=self.device)
|
80 |
|
81 |
+
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
82 |
from mmgp import offload
|
83 |
# model_filename
|
84 |
+
|
85 |
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
|
86 |
# offload.load_model_data(self.model, "e:/vace.safetensors")
|
87 |
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
88 |
# self.model.to(torch.bfloat16)
|
89 |
# self.model.cpu()
|
90 |
+
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
91 |
offload.change_dtype(self.model, dtype, True)
|
92 |
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
|
93 |
# offload.save_model(self.model, "phantom_1.3B.safetensors")
|
|
|
96 |
|
97 |
self.sample_neg_prompt = config.sample_neg_prompt
|
98 |
|
99 |
+
if "Vace" in model_filename[-1]:
|
100 |
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
101 |
min_area=480*832,
|
102 |
max_area=480*832,
|
|
|
108 |
|
109 |
self.adapt_vace_model()
|
110 |
|
111 |
+
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = 0, overlap_noise = 0):
|
112 |
if ref_images is None:
|
113 |
ref_images = [None] * len(frames)
|
114 |
else:
|
|
|
120 |
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
121 |
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
122 |
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
123 |
+
# inactive = [ t * (1.0 - noise_factor) + torch.randn_like(t ) * noise_factor for t in inactive]
|
124 |
+
# if overlapped_latents > 0:
|
125 |
+
# for t in inactive:
|
126 |
+
# t[:, :overlapped_latents ] = t[:, :overlapped_latents ] * (1.0 - noise_factor) + torch.randn_like(t[:, :overlapped_latents ] ) * noise_factor
|
127 |
+
|
128 |
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
129 |
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
130 |
|
|
|
294 |
slg_end = 1.0,
|
295 |
cfg_star_switch = True,
|
296 |
cfg_zero_step = 5,
|
297 |
+
overlapped_latents = 0,
|
298 |
+
overlap_noise = 0,
|
299 |
+
vace = False
|
300 |
+
):
|
301 |
r"""
|
302 |
Generates video frames from text prompt using diffusion process.
|
303 |
|
|
|
352 |
size = (source_video.shape[2], source_video.shape[1])
|
353 |
source_video = source_video.to(dtype=self.dtype , device=self.device)
|
354 |
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
355 |
+
source_latents = self.vae.encode([source_video])[0] #.to(dtype=self.dtype, device=self.device)
|
356 |
del source_video
|
357 |
# Process target camera (recammaster)
|
358 |
from wan.utils.cammmaster_tools import get_camera_embedding
|
359 |
cam_emb = get_camera_embedding(target_camera)
|
360 |
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
361 |
|
362 |
+
if vace :
|
363 |
# vace context encode
|
364 |
input_frames = [u.to(self.device) for u in input_frames]
|
365 |
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
366 |
input_masks = [u.to(self.device) for u in input_masks]
|
367 |
|
368 |
+
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents, overlap_noise = overlap_noise )
|
369 |
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
370 |
z = self.vace_latent(z0, m0)
|
371 |
|
|
|
374 |
else:
|
375 |
if input_ref_images != None: # Phantom Ref images
|
376 |
phantom = True
|
377 |
+
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
378 |
+
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
379 |
F = frame_num
|
380 |
+
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
|
381 |
size[1] // self.vae_stride[1],
|
382 |
size[0] // self.vae_stride[2])
|
383 |
|
|
|
414 |
raise NotImplementedError("Unsupported solver.")
|
415 |
|
416 |
# sample videos
|
417 |
+
latents = noise[0]
|
418 |
del noise
|
419 |
+
batch_size = 1
|
420 |
if target_camera != None:
|
421 |
+
shape = list(latents.shape[1:])
|
422 |
shape[0] *= 2
|
423 |
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
424 |
else:
|
425 |
+
freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
426 |
|
427 |
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
|
428 |
|
429 |
if target_camera != None:
|
430 |
kwargs.update({'cam_emb': cam_emb})
|
431 |
|
432 |
+
if vace:
|
433 |
+
ref_images_count = len(input_ref_images[0]) if input_ref_images != None else 0
|
434 |
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
|
435 |
+
if overlapped_latents > 0:
|
436 |
+
z_reactive = [ zz[0:16, ref_images_count:overlapped_latents + ref_images_count].clone() for zz in z]
|
437 |
|
438 |
|
439 |
if self.model.enable_teacache:
|
440 |
+
x_count = 3 if phantom else 2
|
441 |
+
self.model.previous_residual = [None] * x_count
|
442 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
443 |
if callback != None:
|
444 |
callback(-1, None, True)
|
445 |
for i, t in enumerate(tqdm(timesteps)):
|
446 |
+
if vace and overlapped_latents > 0 :
|
447 |
+
# noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
|
448 |
+
noise_factor = overlap_noise / 1000 # * (999-t) / 999
|
449 |
+
# noise_factor = overlap_noise / 1000 # * t / 999
|
450 |
+
for zz, zz_r in zip(z, z_reactive):
|
451 |
+
zz[0:16, ref_images_count:overlapped_latents + ref_images_count] = zz_r * (1.0 - noise_factor) + torch.randn_like(zz_r ) * noise_factor
|
452 |
+
|
453 |
if target_camera != None:
|
454 |
+
latent_model_input = torch.cat([latents, source_latents], dim=1)
|
455 |
else:
|
456 |
latent_model_input = latents
|
457 |
+
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
458 |
+
|
|
|
459 |
timestep = [t]
|
460 |
offload.set_step_no_for_lora(self.model, i)
|
461 |
timestep = torch.stack(timestep)
|
|
|
464 |
if joint_pass:
|
465 |
if phantom:
|
466 |
pos_it, pos_i, neg = self.model(
|
467 |
+
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
|
468 |
+
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
|
469 |
context = [context, context_null, context_null], **kwargs)
|
470 |
else:
|
471 |
noise_pred_cond, noise_pred_uncond = self.model(
|
472 |
+
[latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
|
473 |
if self._interrupt:
|
474 |
return None
|
475 |
else:
|
476 |
if phantom:
|
477 |
pos_it = self.model(
|
478 |
+
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
|
479 |
)[0]
|
480 |
if self._interrupt:
|
481 |
return None
|
482 |
pos_i = self.model(
|
483 |
+
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
|
484 |
)[0]
|
485 |
if self._interrupt:
|
486 |
return None
|
487 |
neg = self.model(
|
488 |
+
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
|
489 |
)[0]
|
490 |
if self._interrupt:
|
491 |
return None
|
492 |
else:
|
493 |
noise_pred_cond = self.model(
|
494 |
+
[latent_model_input], x_id = 0, context = [context], **kwargs)[0]
|
495 |
if self._interrupt:
|
496 |
return None
|
497 |
noise_pred_uncond = self.model(
|
498 |
+
[latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
|
499 |
if self._interrupt:
|
500 |
return None
|
501 |
|
|
|
525 |
temp_x0 = sample_scheduler.step(
|
526 |
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
527 |
t,
|
528 |
+
latents.unsqueeze(0),
|
529 |
return_dict=False,
|
530 |
generator=seed_g)[0]
|
531 |
+
latents = temp_x0.squeeze(0)
|
532 |
del temp_x0
|
533 |
|
534 |
if callback is not None:
|
535 |
+
callback(i, latents, False)
|
536 |
|
537 |
+
x0 = [latents]
|
538 |
|
539 |
if input_frames == None:
|
540 |
if phantom:
|
541 |
# phantom post processing
|
542 |
+
x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
|
543 |
videos = self.vae.decode(x0, VAE_tile_size)
|
544 |
else:
|
545 |
# vace post processing
|
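The text2video.py changes above thread `overlapped_latents` and `overlap_noise` through the Vace path: before every denoising step, the latent frames shared with the previous sliding window are blended with fresh Gaussian noise using `noise_factor = overlap_noise / 1000`, which is what produces smoother transitions between consecutive Vace windows. The sketch below reproduces just that blend on a dummy tensor; the slicing and the /1000 scaling follow the diff, while the helper name and shapes are illustrative.

import torch

def renoise_overlap(z, z_ref, ref_count, overlap, overlap_noise):
    # Blend the overlapped latent frames with Gaussian noise in place, mirroring
    # the per-step update in the diff; helper name and shapes are illustrative.
    noise_factor = overlap_noise / 1000
    sl = slice(ref_count, ref_count + overlap)
    z[:16, sl] = z_ref * (1.0 - noise_factor) + torch.randn_like(z_ref) * noise_factor
    return z

z = torch.randn(32, 21, 60, 104)     # Vace latent: inactive + reactive channels stacked
z_ref = z[:16, 0:4].clone()          # clean copy of the overlapped slice, saved before sampling
renoise_overlap(z, z_ref, ref_count=0, overlap=4, overlap_noise=200)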
wgp.py
CHANGED
@@ -40,7 +40,7 @@ global_queue_ref = []
|
|
40 |
AUTOSAVE_FILENAME = "queue.zip"
|
41 |
PROMPT_VARS_MAX = 10
|
42 |
|
43 |
-
target_mmgp_version = "3.4.
|
44 |
from importlib.metadata import version
|
45 |
mmgp_version = version("mmgp")
|
46 |
if mmgp_version != target_mmgp_version:
|
@@ -49,32 +49,30 @@ if mmgp_version != target_mmgp_version:
|
|
49 |
lock = threading.Lock()
|
50 |
current_task_id = None
|
51 |
task_id = 0
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
# os.rename(f, os.path.basename(f))
|
77 |
-
# os.remove(zip_name)
|
78 |
|
79 |
def format_time(seconds):
|
80 |
if seconds < 60:
|
@@ -168,14 +166,14 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|
168 |
resolution = inputs["resolution"]
|
169 |
width, height = resolution.split("x")
|
170 |
width, height = int(width), int(height)
|
171 |
-
if test_class_i2v(model_filename):
|
172 |
-
if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
|
173 |
-
|
174 |
-
return
|
175 |
-
resolution = str(width) + "*" + str(height)
|
176 |
-
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
177 |
-
|
178 |
-
|
179 |
|
180 |
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
|
181 |
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
@@ -533,7 +531,7 @@ def save_queue_action(state):
|
|
533 |
task_id_s = task.get('id', f"task_{task_index}")
|
534 |
|
535 |
image_keys = ["image_start", "image_end", "image_refs"]
|
536 |
-
video_keys = ["video_guide", "video_mask", "video_source"]
|
537 |
|
538 |
for key in image_keys:
|
539 |
images_pil = params_copy.get(key)
|
@@ -707,7 +705,7 @@ def load_queue_action(filepath, state, evt:gr.EventData):
|
|
707 |
params['state'] = state
|
708 |
|
709 |
image_keys = ["image_start", "image_end", "image_refs"]
|
710 |
-
video_keys = ["video_guide", "video_mask", "video_source"]
|
711 |
|
712 |
loaded_pil_images = {}
|
713 |
loaded_video_paths = {}
|
@@ -925,7 +923,7 @@ def autosave_queue():
|
|
925 |
task_id_s = task.get('id', f"task_{task_index}")
|
926 |
|
927 |
image_keys = ["image_start", "image_end", "image_refs"]
|
928 |
-
video_keys = ["video_guide", "video_mask", "video_source"]
|
929 |
|
930 |
for key in image_keys:
|
931 |
images_pil = params_copy.get(key)
|
@@ -1418,32 +1416,35 @@ else:
|
|
1418 |
text = reader.read()
|
1419 |
server_config = json.loads(text)
|
1420 |
|
1421 |
-
#
|
1422 |
-
|
1423 |
-
|
1424 |
-
|
1425 |
-
|
1426 |
-
|
|
|
1427 |
|
1428 |
-
path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
|
1429 |
-
if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
|
1430 |
-
os.remove(path)
|
1431 |
|
1432 |
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
|
1433 |
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
|
1434 |
-
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/
|
1435 |
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
|
1436 |
-
transformer_choices_i2v=["ckpts/wan2.
|
1437 |
-
"ckpts/wan2.
|
1438 |
-
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"
|
|
|
1439 |
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
|
1440 |
-
|
1441 |
-
|
|
|
|
|
|
|
|
|
1442 |
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
|
1443 |
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
|
1444 |
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
|
1445 |
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
|
1446 |
-
"phantom_1.3B" : "phantom_1.3B",
|
1447 |
|
1448 |
|
1449 |
def get_model_type(model_filename):
|
@@ -1453,7 +1454,7 @@ def get_model_type(model_filename):
|
|
1453 |
raise Exception("Unknown model:" + model_filename)
|
1454 |
|
1455 |
def test_class_i2v(model_filename):
|
1456 |
-
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
|
1457 |
|
1458 |
def get_model_name(model_filename, description_container = [""]):
|
1459 |
if "Fun" in model_filename:
|
@@ -1491,6 +1492,10 @@ def get_model_name(model_filename, description_container = [""]):
|
|
1491 |
model_name = "Wan2.1 Phantom"
|
1492 |
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
1493 |
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
|
|
|
|
|
|
|
|
|
1494 |
else:
|
1495 |
model_name = "Wan2.1 text2video"
|
1496 |
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
@@ -1536,6 +1541,7 @@ def get_default_settings(filename):
|
|
1536 |
"repeat_generation": 1,
|
1537 |
"multi_images_gen_type": 0,
|
1538 |
"guidance_scale": 5.0,
|
|
|
1539 |
"flow_shift": get_default_flow(filename, i2v),
|
1540 |
"negative_prompt": "",
|
1541 |
"activated_loras": [],
|
@@ -1719,8 +1725,9 @@ def download_models(transformer_filename, text_encoder_filename):
|
|
1719 |
|
1720 |
from huggingface_hub import hf_hub_download, snapshot_download
|
1721 |
repoId = "DeepBeepMeep/Wan2.1"
|
1722 |
-
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
|
1723 |
-
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["
|
|
|
1724 |
targetRoot = "ckpts/"
|
1725 |
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
1726 |
if len(files)==0:
|
@@ -1834,12 +1841,13 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
|
|
1834 |
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
|
1835 |
|
1836 |
|
1837 |
-
def load_t2v_model(model_filename,
|
1838 |
|
1839 |
cfg = WAN_CONFIGS['t2v-14B']
|
|
|
1840 |
# cfg = WAN_CONFIGS['t2v-1.3B']
|
1841 |
-
print(f"Loading '{
|
1842 |
-
if
|
1843 |
model_factory = wan.DTT2V
|
1844 |
else:
|
1845 |
model_factory = wan.WanT2V
|
@@ -1859,9 +1867,10 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
|
|
1859 |
|
1860 |
return wan_model, pipe
|
1861 |
|
1862 |
-
def load_i2v_model(model_filename,
|
1863 |
|
1864 |
-
|
|
|
1865 |
|
1866 |
cfg = WAN_CONFIGS['i2v-14B']
|
1867 |
wan_model = wan.WanI2V(
|
@@ -1883,7 +1892,6 @@ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = t
|
|
1883 |
def load_models(model_filename):
|
1884 |
global transformer_filename
|
1885 |
|
1886 |
-
transformer_filename = model_filename
|
1887 |
perc_reserved_mem_max = args.perc_reserved_mem_max
|
1888 |
|
1889 |
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
|
@@ -1892,19 +1900,27 @@ def load_models(model_filename):
|
|
1892 |
default_dtype = torch.float16
|
1893 |
else:
|
1894 |
default_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
1895 |
-
|
1896 |
-
|
1897 |
-
|
1898 |
-
|
|
|
|
|
1899 |
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
|
1900 |
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
|
1901 |
-
|
1902 |
-
|
1903 |
-
|
|
|
1904 |
else:
|
1905 |
-
wan_model, pipe = load_t2v_model(
|
1906 |
-
wan_model._model_file_name =
|
1907 |
-
kwargs = { "extraModelsToQuantize": None}
|
1908 |
if profile == 2 or profile == 4:
|
1909 |
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
|
1910 |
# if profile == 4:
|
@@ -1914,7 +1930,7 @@ def load_models(model_filename):
|
|
1914 |
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
|
1915 |
if len(args.gpu) > 0:
|
1916 |
torch.set_default_device(args.gpu)
|
1917 |
-
|
1918 |
return wan_model, offloadobj, pipe["transformer"]
|
1919 |
|
1920 |
if not "P" in preload_model_policy:
|
@@ -2033,13 +2049,7 @@ def apply_changes( state,
|
|
2033 |
preload_model_policy = server_config["preload_model_policy"]
|
2034 |
transformer_quantization = server_config["transformer_quantization"]
|
2035 |
transformer_types = server_config["transformer_types"]
|
2036 |
-
model_filename = state["model_filename"]
|
2037 |
-
model_transformer_type = get_model_type(model_filename)
|
2038 |
|
2039 |
-
if not model_transformer_type in transformer_types:
|
2040 |
-
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
|
2041 |
-
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
|
2042 |
-
state["model_filename"] = model_filename
|
2043 |
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
|
2044 |
model_choice = gr.Dropdown()
|
2045 |
else:
|
@@ -2249,7 +2259,9 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
|
|
2249 |
frame_height, frame_width, _ = frames_list[0].shape
|
2250 |
|
2251 |
if fit_canvas :
|
2252 |
-
|
|
|
|
|
2253 |
else:
|
2254 |
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
2255 |
|
@@ -2356,6 +2368,7 @@ def generate_video(
|
|
2356 |
seed,
|
2357 |
num_inference_steps,
|
2358 |
guidance_scale,
|
|
|
2359 |
flow_shift,
|
2360 |
embedded_guidance_scale,
|
2361 |
repeat_generation,
|
@@ -2375,8 +2388,10 @@ def generate_video(
|
|
2375 |
video_guide,
|
2376 |
keep_frames_video_guide,
|
2377 |
video_mask,
|
|
|
2378 |
sliding_window_size,
|
2379 |
sliding_window_overlap,
|
|
|
2380 |
sliding_window_discard_last_frames,
|
2381 |
remove_background_image_ref,
|
2382 |
temporal_upsampling,
|
@@ -2508,6 +2523,15 @@ def generate_video(
|
|
2508 |
# VAE Tiling
|
2509 |
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
|
2510 |
|
|
|
|
|
|
|
|
|
2511 |
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
2512 |
# TeaCache
|
2513 |
trans.enable_teacache = tea_cache_setting > 0
|
@@ -2517,12 +2541,10 @@ def generate_video(
|
|
2517 |
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
|
2518 |
|
2519 |
if image2video:
|
2520 |
-
if '
|
2521 |
-
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
2522 |
-
elif '720p' in model_filename:
|
2523 |
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
2524 |
else:
|
2525 |
-
|
2526 |
else:
|
2527 |
if '1.3B' in model_filename:
|
2528 |
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
|
@@ -2535,6 +2557,18 @@ def generate_video(
|
|
2535 |
if "recam" in model_filename:
|
2536 |
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
|
2537 |
target_camera = model_mode
|
|
|
|
|
|
|
|
|
|
|
|
2538 |
import random
|
2539 |
if seed == None or seed <0:
|
2540 |
seed = random.randint(0, 999999999)
|
@@ -2551,11 +2585,11 @@ def generate_video(
|
|
2551 |
extra_generation = 0
|
2552 |
initial_total_windows = 0
|
2553 |
max_frames_to_generate = video_length
|
2554 |
-
diffusion_forcing = "diffusion_forcing" in model_filename
|
2555 |
-
vace = "Vace" in model_filename
|
2556 |
phantom = "phantom" in model_filename
|
2557 |
if diffusion_forcing or vace:
|
2558 |
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
|
|
|
|
|
2559 |
if diffusion_forcing and source_video != None:
|
2560 |
video_length += sliding_window_overlap
|
2561 |
sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
|
@@ -2571,10 +2605,8 @@ def generate_video(
|
|
2571 |
initial_total_windows = 1
|
2572 |
|
2573 |
first_window_video_length = video_length
|
2574 |
-
fps = 24 if diffusion_forcing else 16
|
2575 |
|
2576 |
gen["sliding_window"] = sliding_window
|
2577 |
-
|
2578 |
while not abort:
|
2579 |
extra_generation += gen.get("extra_orders",0)
|
2580 |
gen["extra_orders"] = 0
|
@@ -2594,6 +2626,7 @@ def generate_video(
|
|
2594 |
guide_start_frame = 0
|
2595 |
video_length = first_window_video_length
|
2596 |
gen["extra_windows"] = 0
|
|
|
2597 |
while not abort:
|
2598 |
if sliding_window:
|
2599 |
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
|
@@ -2642,7 +2675,7 @@ def generate_video(
|
|
2642 |
|
2643 |
if preprocess_type != None :
|
2644 |
send_cmd("progress", progress_args)
|
2645 |
-
video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps)
|
2646 |
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
|
2647 |
if len(error) > 0:
|
2648 |
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
@@ -2678,8 +2711,7 @@ def generate_video(
|
|
2678 |
trans.teacache_counter = 0
|
2679 |
trans.num_steps = num_inference_steps
|
2680 |
trans.teacache_skipped_steps = 0
|
2681 |
-
trans.
|
2682 |
-
trans.previous_residual_cond = None
|
2683 |
|
2684 |
if image2video:
|
2685 |
samples = wan_model.generate(
|
@@ -2687,7 +2719,9 @@ def generate_video(
|
|
2687 |
image_start,
|
2688 |
image_end if image_end != None else None,
|
2689 |
frame_num=(video_length // 4)* 4 + 1,
|
2690 |
-
max_area=MAX_AREA_CONFIGS[resolution_reformated],
|
|
|
|
|
2691 |
shift=flow_shift,
|
2692 |
sampling_steps=num_inference_steps,
|
2693 |
guide_scale=guidance_scale,
|
@@ -2702,7 +2736,11 @@ def generate_video(
|
|
2702 |
slg_end = slg_end_perc/100,
|
2703 |
cfg_star_switch = cfg_star_switch,
|
2704 |
cfg_zero_step = cfg_zero_step,
|
2705 |
-
add_frames_for_end_image = "image2video" in model_filename
|
|
|
|
|
|
|
|
|
2706 |
)
|
2707 |
elif diffusion_forcing:
|
2708 |
samples = wan_model.generate(
|
@@ -2720,14 +2758,17 @@ def generate_video(
|
|
2720 |
callback= callback,
|
2721 |
VAE_tile_size = VAE_tile_size,
|
2722 |
joint_pass = joint_pass,
|
2723 |
-
|
|
|
|
|
|
|
2724 |
ar_step = model_mode, #5
|
2725 |
causal_block_size = 5,
|
2726 |
causal_attention = True,
|
2727 |
fps = fps,
|
2728 |
)
|
2729 |
else:
|
2730 |
-
|
2731 |
prompt,
|
2732 |
input_frames = src_video,
|
2733 |
input_ref_images= src_ref_images,
|
@@ -2751,6 +2792,9 @@ def generate_video(
|
|
2751 |
slg_end = slg_end_perc/100,
|
2752 |
cfg_star_switch = cfg_star_switch,
|
2753 |
cfg_zero_step = cfg_zero_step,
|
|
|
|
|
|
|
2754 |
)
|
2755 |
except Exception as e:
|
2756 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
@@ -2782,11 +2826,11 @@ def generate_video(
|
|
2782 |
print('\n'.join(tb))
|
2783 |
send_cmd("error", new_error)
|
2784 |
return
|
|
|
|
|
2785 |
|
2786 |
if trans.enable_teacache:
|
2787 |
-
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{
|
2788 |
-
trans.previous_residual_uncond = None
|
2789 |
-
trans.previous_residual_cond = None
|
2790 |
|
2791 |
if samples != None:
|
2792 |
samples = samples.to("cpu")
|
@@ -2810,14 +2854,27 @@ def generate_video(
|
|
2810 |
if discard_last_frames > 0:
|
2811 |
sample = sample[: , :-discard_last_frames]
|
2812 |
guide_start_frame -= discard_last_frames
|
2813 |
-
|
|
|
|
|
|
2814 |
if prefix_video != None:
|
2815 |
-
|
|
|
|
|
|
|
2816 |
prefix_video = None
|
2817 |
if sliding_window and window_no > 1:
|
2818 |
-
|
2819 |
-
|
|
|
|
|
2820 |
|
|
|
2821 |
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
2822 |
if os.name == 'nt':
|
2823 |
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
|
@@ -2875,18 +2932,23 @@ def generate_video(
|
|
2875 |
sample = torch.cat([frames_already_processed, sample], dim=1)
|
2876 |
frames_already_processed = sample
|
2877 |
|
2878 |
-
|
2879 |
-
tensor=sample[None],
|
2880 |
-
|
2881 |
-
|
2882 |
-
nrow=1,
|
2883 |
-
|
2884 |
-
|
|
|
|
|
|
|
|
|
2885 |
|
2886 |
inputs = get_function_arguments(generate_video, locals())
|
2887 |
inputs.pop("send_cmd")
|
|
|
2888 |
configs = prepare_inputs_dict("metadata", inputs)
|
2889 |
-
|
2890 |
metadata_choice = server_config.get("metadata_type","metadata")
|
2891 |
if metadata_choice == "json":
|
2892 |
with open(video_path.replace('.mp4', '.json'), 'w') as f:
|
@@ -3113,7 +3175,7 @@ def get_latest_status(state):
|
|
3113 |
prompt_no = gen["prompt_no"]
|
3114 |
prompts_max = gen.get("prompts_max",0)
|
3115 |
total_generation = gen.get("total_generation", 1)
|
3116 |
-
repeat_no = gen
|
3117 |
total_generation += gen.get("extra_orders", 0)
|
3118 |
total_windows = gen.get("total_windows", 0)
|
3119 |
total_windows += gen.get("extra_windows", 0)
|
@@ -3456,7 +3518,7 @@ def prepare_inputs_dict(target, inputs ):
|
|
3456 |
|
3457 |
if target == "state":
|
3458 |
return inputs
|
3459 |
-
unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"]
|
3460 |
for k in unsaved_params:
|
3461 |
inputs.pop(k)
|
3462 |
|
@@ -3484,10 +3546,14 @@ def prepare_inputs_dict(target, inputs ):
|
|
3484 |
inputs.pop(k)
|
3485 |
|
3486 |
if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
|
3487 |
-
unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"]
|
3488 |
for k in unsaved_params:
|
3489 |
inputs.pop(k)
|
3490 |
|
|
|
|
|
|
|
|
|
3491 |
if target == "metadata":
|
3492 |
inputs = {k: v for k,v in inputs.items() if v != None }
|
3493 |
|
@@ -3511,6 +3577,7 @@ def save_inputs(
|
|
3511 |
seed,
|
3512 |
num_inference_steps,
|
3513 |
guidance_scale,
|
|
|
3514 |
flow_shift,
|
3515 |
embedded_guidance_scale,
|
3516 |
repeat_generation,
|
@@ -3530,8 +3597,10 @@ def save_inputs(
|
|
3530 |
video_guide,
|
3531 |
keep_frames_video_guide,
|
3532 |
video_mask,
|
|
|
3533 |
sliding_window_size,
|
3534 |
sliding_window_overlap,
|
|
|
3535 |
sliding_window_discard_last_frames,
|
3536 |
remove_background_image_ref,
|
3537 |
temporal_upsampling,
|
@@ -3834,6 +3903,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
3834 |
recammaster = "recam" in model_filename
|
3835 |
vace = "Vace" in model_filename
|
3836 |
phantom = "phantom" in model_filename
|
|
|
3837 |
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
|
3838 |
if diffusion_forcing:
|
3839 |
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
|
@@ -3939,7 +4009,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3939
3940
3941   video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
3942 -
3943
3944   advanced_prompt = advanced_ui
3945   prompt_vars=[]

@@ -3972,12 +4042,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3972   wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
3973   wizard_variables_var = gr.Text(wizard_variables, visible = False)
3974   with gr.Row():
3975 - if test_class_i2v(model_filename):
3976   resolution = gr.Dropdown(
3977   choices=[
3978   # 720p
3979 - ("720p", "1280x720"),
3980 - ("480p", "832x480"),
3981   ],
3982   value=ui_defaults.get("resolution","480p"),
3983   label="Resolution (video will have the same height / width ratio than the original image)"

@@ -3989,19 +4059,21 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3989   ("1280x720 (16:9, 720p)", "1280x720"),
3990   ("720x1280 (9:16, 720p)", "720x1280"),
3991   ("1024x1024 (4:3, 720p)", "1024x024"),
3992 -
3993 -
3994   # 480p
3995   ("960x544 (16:9, 540p)", "960x544"),
3996   ("544x960 (16:9, 540p)", "544x960"),
3997   ("832x480 (16:9, 480p)", "832x480"),
3998   ("480x832 (9:16, 480p)", "480x832"),
3999 -
4000 -
4001 -
4002   ],
4003   value=ui_defaults.get("resolution","832x480"),
4004 - label="Resolution"
4005   )
4006   with gr.Row():
4007   if recammaster:

@@ -4010,6 +4082,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4010   video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
4011   elif vace:
4012   video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4013   else:
4014   video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4015   with gr.Row():

@@ -4029,10 +4103,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4029   choices=[
4030   ("Generate every combination of images and texts", 0),
4031   ("Match images and text prompts", 1),
4032 - ], visible=
4033   )
4034   with gr.Row():
4035   guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
4036   embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
4037   flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
4038   with gr.Row():
@@ -4099,7 +4174,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4099
4100   with gr.Tab("Quality"):
4101   with gr.Row():
4102 - gr.Markdown("<B>
4103   with gr.Row():
4104   slg_switch = gr.Dropdown(
4105   choices=[

@@ -4148,11 +4223,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4148   if diffusion_forcing:
4149   sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
4150   sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4151 -
4152   else:
4153   sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4154 - sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",
4155 -
4156
4157
4158   with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):

@@ -4167,8 +4244,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4167   label="RIFLEx positional embedding to generate long video"
4168   )
4169
4170 -
4171 -
4172
4173   if not update_form:
4174   with gr.Column():

@@ -5035,7 +5112,7 @@ def create_demo():
5035   theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
5036
5037   with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
5038 - gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.
5039   global model_list
5040
5041   tab_state = gr.State({ "tab_no":0 })

@@ -5076,7 +5153,7 @@ def create_demo():
5076
5077   if __name__ == "__main__":
5078   atexit.register(autosave_queue)
5079 -
5080   # threading.Thread(target=runner, daemon=True).start()
5081   os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
5082   server_port = int(args.server_port)
40   AUTOSAVE_FILENAME = "queue.zip"
41   PROMPT_VARS_MAX = 10
42
43 + target_mmgp_version = "3.4.2"
44   from importlib.metadata import version
45   mmgp_version = version("mmgp")
46   if mmgp_version != target_mmgp_version:

49   lock = threading.Lock()
50   current_task_id = None
51   task_id = 0
52 +
53 + def download_ffmpeg():
54 + if os.name != 'nt': return
55 + exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
56 + if all(os.path.exists(e) for e in exes): return
57 + api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
58 + r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
59 + assets = r.json().get('assets', [])
60 + zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
61 + if not zip_asset: return
62 + zip_url = zip_asset['browser_download_url']
63 + zip_name = zip_asset['name']
64 + with requests.get(zip_url, stream=True) as resp:
65 + total = int(resp.headers.get('Content-Length', 0))
66 + with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
67 + for chunk in resp.iter_content(chunk_size=8192):
68 + f.write(chunk)
69 + pbar.update(len(chunk))
70 + with zipfile.ZipFile(zip_name) as z:
71 + for f in z.namelist():
72 + if f.endswith(tuple(exes)) and '/bin/' in f:
73 + z.extract(f)
74 + os.rename(f, os.path.basename(f))
75 + os.remove(zip_name)
76
77   def format_time(seconds):
78   if seconds < 60:
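The download_ffmpeg() helper added above only acts on Windows: it asks the GitHub API for the latest GyanD/codexffmpeg "essentials" release, streams the zip with a tqdm progress bar, and keeps just the three executables from its bin/ folder. A minimal standalone sketch of the same flow with basic error handling is shown below; fetch_windows_ffmpeg and dest_dir are illustrative names, not part of wgp.py.

```python
# Sketch only: a defensive variant of the Windows ffmpeg bootstrap above.
import os
import zipfile
import requests

def fetch_windows_ffmpeg(dest_dir="."):
    if os.name != "nt":
        return  # the binaries are only needed on Windows
    exes = ["ffmpeg.exe", "ffprobe.exe", "ffplay.exe"]
    if all(os.path.exists(os.path.join(dest_dir, e)) for e in exes):
        return  # already present, nothing to do
    api_url = "https://api.github.com/repos/GyanD/codexffmpeg/releases/latest"
    r = requests.get(api_url, headers={"Accept": "application/vnd.github+json"}, timeout=30)
    r.raise_for_status()
    asset = next((a for a in r.json().get("assets", [])
                  if "essentials_build.zip" in a["name"]), None)
    if asset is None:
        return
    zip_path = os.path.join(dest_dir, asset["name"])
    with requests.get(asset["browser_download_url"], stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(zip_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)
    with zipfile.ZipFile(zip_path) as z:
        for name in z.namelist():
            # keep only the three executables that live under .../bin/
            if "/bin/" in name and name.endswith(tuple(exes)):
                z.extract(name, dest_dir)
                os.replace(os.path.join(dest_dir, name),
                           os.path.join(dest_dir, os.path.basename(name)))
    os.remove(zip_path)
```

In the diff itself, download_ffmpeg() is then called once at startup from the __main__ block further down.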
166   resolution = inputs["resolution"]
167   width, height = resolution.split("x")
168   width, height = int(width), int(height)
169 + # if test_class_i2v(model_filename):
170 + # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
171 + # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
172 + # return
173 + # resolution = str(width) + "*" + str(height)
174 + # if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
175 + # gr.Info(f"Resolution {resolution} not supported by image 2 video")
176 + # return
177
178   if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
179   gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")

531   task_id_s = task.get('id', f"task_{task_index}")
532
533   image_keys = ["image_start", "image_end", "image_refs"]
534 + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
535
536   for key in image_keys:
537   images_pil = params_copy.get(key)

705   params['state'] = state
706
707   image_keys = ["image_start", "image_end", "image_refs"]
708 + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
709
710   loaded_pil_images = {}
711   loaded_video_paths = {}

923   task_id_s = task.get('id', f"task_{task_index}")
924
925   image_keys = ["image_start", "image_end", "image_refs"]
926 + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"]
927
928   for key in image_keys:
929   images_pil = params_copy.get(key)

1416   text = reader.read()
1417   server_config = json.loads(text)
1418
1419 + # Deprecated models
1420 + for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors",
1421 + "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors",
1422 + "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
1423 + ]:
1424 + if Path(os.path.join("ckpts" , path)).is_file():
1425 + os.remove( os.path.join("ckpts" , path))
1426

1427
1428   transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
1429   "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
1430 + "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors",
1431   "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
1432 + transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors",
1433 + "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
1434 + "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
1435 + "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"]
1436   transformer_choices = transformer_choices_t2v + transformer_choices_i2v
1437 + def get_dependent_models(model_filename, quantization ):
1438 + if "fantasy" in model_filename:
1439 + return [get_model_filename("i2v_720p", quantization)]
1440 + else:
1441 + return []
1442 + model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B", "fantasy"]
1443   model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
1444   "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
1445   "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
1446   "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
1447 + "phantom_1.3B" : "phantom_1.3B", "fantasy" : "fantasy" }
1448
1449
1450   def get_model_type(model_filename):
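get_dependent_models() is how the new Fantasy Speaking checkpoint declares that it rides on top of the regular 720p image-to-video transformer: load_models() later prepends the returned list to the requested filename and always treats the last entry as the model actually being run. A small sketch of that contract follows; the get_model_filename stub is a hypothetical stand-in for the real helper.

```python
# Sketch of how the dependency list is consumed (get_model_filename is stubbed out here).
def get_model_filename(model_type, quantization):
    # hypothetical stand-in: the real helper maps a model type to a checkpoint path
    return f"ckpts/{model_type}_{quantization or 'bf16'}.safetensors"

def get_dependent_models(model_filename, quantization):
    # Fantasy Speaking needs the 720p i2v transformer loaded alongside it
    return [get_model_filename("i2v_720p", quantization)] if "fantasy" in model_filename else []

main_model = "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"
model_filelist = get_dependent_models(main_model, "int8") + [main_model]
assert model_filelist[-1] == main_model  # the last entry is always the model that was requested
```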
1454   raise Exception("Unknown model:" + model_filename)
1455
1456   def test_class_i2v(model_filename):
1457 + return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename or "fantasy" in model_filename
1458
1459   def get_model_name(model_filename, description_container = [""]):
1460   if "Fun" in model_filename:

1492   model_name = "Wan2.1 Phantom"
1493   model_name += " 14B" if "14B" in model_filename else " 1.3B"
1494   description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
1495 + elif "fantasy" in model_filename:
1496 + model_name = "Wan2.1 Fantasy Speaking 720p"
1497 + model_name += " 14B" if "14B" in model_filename else " 1.3B"
1498 + description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input."
1499   else:
1500   model_name = "Wan2.1 text2video"
1501   model_name += " 14B" if "14B" in model_filename else " 1.3B"

1541   "repeat_generation": 1,
1542   "multi_images_gen_type": 0,
1543   "guidance_scale": 5.0,
1544 + "audio_guidance_scale": 5.0,
1545   "flow_shift": get_default_flow(filename, i2v),
1546   "negative_prompt": "",
1547   "activated_loras": [],

1725
1726   from huggingface_hub import hf_hub_download, snapshot_download
1727   repoId = "DeepBeepMeep/Wan2.1"
1728 + sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "wav2vec", "" ]
1729 + fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
1730 + ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1731   targetRoot = "ckpts/"
1732   for sourceFolder, files in zip(sourceFolderList,fileList ):
1733   if len(files)==0:

1841   return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
1842
1843
1844 + def load_t2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1845
1846   cfg = WAN_CONFIGS['t2v-14B']
1847 + filename = model_filename[-1]
1848   # cfg = WAN_CONFIGS['t2v-1.3B']
1849 + print(f"Loading '{filename}' model...")
1850 + if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
1851   model_factory = wan.DTT2V
1852   else:
1853   model_factory = wan.WanT2V

1867
1868   return wan_model, pipe
1869
1870 + def load_i2v_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1871
1872 + filename = model_filename[-1]
1873 + print(f"Loading '{filename}' model...")
1874
1875   cfg = WAN_CONFIGS['i2v-14B']
1876   wan_model = wan.WanI2V(

1892   def load_models(model_filename):
1893   global transformer_filename
1895   perc_reserved_mem_max = args.perc_reserved_mem_max
1896
1897   major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
1900   default_dtype = torch.float16
1901   else:
1902   default_dtype = torch.float16 if args.fp16 else torch.bfloat16
1903 + model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
1904 + updated_model_filename = []
1905 + for filename in model_filelist:
1906 + if default_dtype == torch.float16 :
1907 + if "quanto_int8" in filename:
1908 + filename = filename.replace("quanto_int8", "quanto_fp16_int8")
1909 + elif "quanto_mbf16_int8" in filename:
1910 + filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
1911 + updated_model_filename.append(filename)
1912 + download_models(filename, text_encoder_filename)
1913 + model_filelist = updated_model_filename
1914   VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
1915   mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
1916 + transformer_filename = None
1917 + new_transformer_filename = model_filelist[-1]
1918 + if test_class_i2v(new_transformer_filename):
1919 + wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1920   else:
1921 + wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1922 + wan_model._model_file_name = new_transformer_filename
1923 + kwargs = { "extraModelsToQuantize": None}
1924   if profile == 2 or profile == 4:
1925   kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
1926   # if profile == 4:
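When float16 is selected, the loop above rewrites each quantized checkpoint name so the fp16-calibrated variant is downloaded instead ("quanto_int8" becomes "quanto_fp16_int8", "quanto_mbf16_int8" becomes "quanto_mfp16_int8"). A tiny sketch of that rename, assuming the same naming scheme; to_fp16_variant is an illustrative name, not part of wgp.py.

```python
# Sketch: remap quantized checkpoint names to their fp16 counterparts (same scheme as above).
def to_fp16_variant(filename: str) -> str:
    if "quanto_int8" in filename:
        return filename.replace("quanto_int8", "quanto_fp16_int8")
    if "quanto_mbf16_int8" in filename:
        return filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
    return filename

print(to_fp16_variant("ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors"))
# -> ckpts/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors
```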
1930   offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
1931   if len(args.gpu) > 0:
1932   torch.set_default_device(args.gpu)
1933 + transformer_filename = new_transformer_filename
1934   return wan_model, offloadobj, pipe["transformer"]
1935
1936   if not "P" in preload_model_policy:

2049   preload_model_policy = server_config["preload_model_policy"]
2050   transformer_quantization = server_config["transformer_quantization"]
2051   transformer_types = server_config["transformer_types"]
2052
2053   if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
2054   model_choice = gr.Dropdown()
2055   else:

2259   frame_height, frame_width, _ = frames_list[0].shape
2260
2261   if fit_canvas :
2262 + scale1 = min(height / frame_height, width / frame_width)
2263 + scale2 = min(height / frame_width, width / frame_height)
2264 + scale = max(scale1, scale2)
2265   else:
2266   scale = ((height * width ) / (frame_height * frame_width))**(1/2)
2267
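With fit_canvas enabled, the new code tries the target canvas in both orientations and keeps the larger fit, instead of only matching pixel area. A worked example with illustrative numbers (an 832x480 canvas and a 1920x1080 source frame):

```python
# Worked example of the fit_canvas scaling above (values are illustrative).
height, width = 480, 832                 # target canvas
frame_height, frame_width = 1080, 1920   # source video frame

scale1 = min(height / frame_height, width / frame_width)   # same orientation: ~0.433
scale2 = min(height / frame_width, width / frame_height)   # swapped orientation: 0.25
scale = max(scale1, scale2)                                # ~0.433

print(round(frame_width * scale), round(frame_height * scale))  # -> 832 468
```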
2368   seed,
2369   num_inference_steps,
2370   guidance_scale,
2371 + audio_guidance_scale,
2372   flow_shift,
2373   embedded_guidance_scale,
2374   repeat_generation,

2388   video_guide,
2389   keep_frames_video_guide,
2390   video_mask,
2391 + audio_guide,
2392   sliding_window_size,
2393   sliding_window_overlap,
2394 + sliding_window_overlap_noise,
2395   sliding_window_discard_last_frames,
2396   remove_background_image_ref,
2397   temporal_upsampling,

2523   # VAE Tiling
2524   device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
2525
2526 + diffusion_forcing = "diffusion_forcing" in model_filename
2527 + vace = "Vace" in model_filename
2528 + if diffusion_forcing:
2529 + fps = 24
2530 + elif audio_guide != None:
2531 + fps = 23
2532 + else:
2533 + fps = 16
2534 +
2535   joint_pass = boost ==1 #and profile != 1 and profile != 3
2536   # TeaCache
2537   trans.enable_teacache = tea_cache_setting > 0
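The target frame rate now depends on the model: 24 fps for the Diffusion Forcing models, 23 fps when a Fantasy Speaking audio guide is supplied, and 16 fps otherwise. The same rule restated as a small helper; pick_fps is a hypothetical name used only for this sketch:

```python
# Hypothetical helper restating the fps selection rule above.
def pick_fps(model_filename: str, audio_guide) -> int:
    if "diffusion_forcing" in model_filename:
        return 24
    if audio_guide is not None:
        return 23   # Fantasy Speaking voice track
    return 16

assert pick_fps("sky_reels2_diffusion_forcing_14B", None) == 24
assert pick_fps("wan2.1_fantasy_speaking_14B", "voice.wav") == 23
assert pick_fps("wan2.1_text2video_14B", None) == 16
```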
2541   trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
2542
2543   if image2video:
2544 + if '720p' in model_filename:
2545   trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
2546   else:
2547 + trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
2548   else:
2549   if '1.3B' in model_filename:
2550   trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]

2557   if "recam" in model_filename:
2558   source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
2559   target_camera = model_mode
2560 +
2561 + audio_proj_split = None
2562 + audio_scale = None
2563 + audio_context_lens = None
2564 + if audio_guide != None:
2565 + from fantasytalking.infer import parse_audio
2566 + import librosa
2567 + duration = librosa.get_duration(path=audio_guide)
2568 + video_length = min(int(fps * duration // 4) * 4 + 5, video_length)
2569 + audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= video_length, fps= fps, device= processing_device )
2570 + audio_scale = 1.0
2571 +
2572   import random
2573   if seed == None or seed <0:
2574   seed = random.randint(0, 999999999)
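When an audio guide is present, the frame count is capped so the generated video cannot outlast the voice track, while staying on the 4n+5 frame grid used by the rest of the pipeline. A worked example for a 10 second clip at 23 fps (requested_length is an illustrative value):

```python
# Worked example of the frame-count clamp above (10 s voice clip, 23 fps).
fps, duration, requested_length = 23, 10.0, 281

audio_frames = int(fps * duration // 4) * 4 + 5   # 57 * 4 + 5 = 233
video_length = min(audio_frames, requested_length)
print(video_length)  # -> 233
```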
2585   extra_generation = 0
2586   initial_total_windows = 0
2587   max_frames_to_generate = video_length
2588   phantom = "phantom" in model_filename
2589   if diffusion_forcing or vace:
2590   reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
2591 + else:
2592 + reuse_frames = 0
2593   if diffusion_forcing and source_video != None:
2594   video_length += sliding_window_overlap
2595   sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size

2605   initial_total_windows = 1
2606
2607   first_window_video_length = video_length
2608
2609   gen["sliding_window"] = sliding_window
2610   while not abort:
2611   extra_generation += gen.get("extra_orders",0)
2612   gen["extra_orders"] = 0

2626   guide_start_frame = 0
2627   video_length = first_window_video_length
2628   gen["extra_windows"] = 0
2629 + start_time = time.time()
2630   while not abort:
2631   if sliding_window:
2632   prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]

2675
2676   if preprocess_type != None :
2677   send_cmd("progress", progress_args)
2678 + video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = True, target_fps = fps)
2679   keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
2680   if len(error) > 0:
2681   raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")

2711   trans.teacache_counter = 0
2712   trans.num_steps = num_inference_steps
2713   trans.teacache_skipped_steps = 0
2714 + trans.previous_residual = None
2715
2716   if image2video:
2717   samples = wan_model.generate(

2719   image_start,
2720   image_end if image_end != None else None,
2721   frame_num=(video_length // 4)* 4 + 1,
2722 + # max_area=MAX_AREA_CONFIGS[resolution_reformated],
2723 + height = height,
2724 + width = width,
2725   shift=flow_shift,
2726   sampling_steps=num_inference_steps,
2727   guide_scale=guidance_scale,

2736   slg_end = slg_end_perc/100,
2737   cfg_star_switch = cfg_star_switch,
2738   cfg_zero_step = cfg_zero_step,
2739 + add_frames_for_end_image = "image2video" in model_filename,
2740 + audio_cfg_scale= audio_guidance_scale,
2741 + audio_proj= audio_proj_split,
2742 + audio_scale= audio_scale,
2743 + audio_context_lens= audio_context_lens
2744   )
2745   elif diffusion_forcing:
2746   samples = wan_model.generate(

2758   callback= callback,
2759   VAE_tile_size = VAE_tile_size,
2760   joint_pass = joint_pass,
2761 + slg_layers = slg_layers,
2762 + slg_start = slg_start_perc/100,
2763 + slg_end = slg_end_perc/100,
2764 + addnoise_condition = sliding_window_overlap_noise,
2765   ar_step = model_mode, #5
2766   causal_block_size = 5,
2767   causal_attention = True,
2768   fps = fps,
2769   )
2770   else:
2771 + samples = wan_model.generate(
2772   prompt,
2773   input_frames = src_video,
2774   input_ref_images= src_ref_images,

2792   slg_end = slg_end_perc/100,
2793   cfg_star_switch = cfg_star_switch,
2794   cfg_zero_step = cfg_zero_step,
2795 + overlapped_latents = 0 if reuse_frames == 0 or window_no == 1 else ((reuse_frames - 1) // 4 + 1),
2796 + overlap_noise = sliding_window_overlap_noise,
2797 + vace = vace
2798   )
2799   except Exception as e:
2800   if temp_filename!= None and os.path.isfile(temp_filename):
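For the Vace sliding-window path, the number of overlapped latent frames handed to the generator is derived from the pixel-frame overlap with the usual 4x temporal compression: (reuse_frames - 1) // 4 + 1 from the second window onwards. A quick check of that arithmetic:

```python
# Quick check of the overlapped-latents arithmetic above.
def overlapped_latents(reuse_frames: int, window_no: int) -> int:
    return 0 if reuse_frames == 0 or window_no == 1 else (reuse_frames - 1) // 4 + 1

assert overlapped_latents(5, 1) == 0    # the first window never overlaps
assert overlapped_latents(5, 2) == 2    # 5 overlapped frames -> 2 latent frames
assert overlapped_latents(17, 3) == 5   # 17 overlapped frames -> 5 latent frames
```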
2826   print('\n'.join(tb))
2827   send_cmd("error", new_error)
2828   return
2829 + finally:
2830 + trans.previous_residual = None
2831
2832   if trans.enable_teacache:
2833 + print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{trans.num_steps}" )
2834
2835   if samples != None:
2836   samples = samples.to("cpu")

2854   if discard_last_frames > 0:
2855   sample = sample[: , :-discard_last_frames]
2856   guide_start_frame -= discard_last_frames
2857 + if reuse_frames == 0:
2858 + pre_video_guide = sample[:,9999 :]
2859 + else:
2860 + # noise_factor = 200/ 1000
2861 + # pre_video_guide = sample[:, -reuse_frames:] * (1.0 - noise_factor) + torch.randn_like(sample[:, -reuse_frames:] ) * noise_factor
2862 + pre_video_guide = sample[:, -reuse_frames:]
2863 +
2864 +
2865   if prefix_video != None:
2866 + if reuse_frames == 0:
2867 + sample = torch.cat([ prefix_video[:, :], sample], dim = 1)
2868 + else:
2869 + sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
2870   prefix_video = None
2871   if sliding_window and window_no > 1:
2872 + if reuse_frames == 0:
2873 + sample = sample[: , :]
2874 + else:
2875 + sample = sample[: , reuse_frames:]
2876
2877 + guide_start_frame -= reuse_frames
2878   time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
2879   if os.name == 'nt':
2880   file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"

2932   sample = torch.cat([frames_already_processed, sample], dim=1)
2933   frames_already_processed = sample
2934
2935 + if audio_guide == None:
2936 + cache_video( tensor=sample[None], save_file=video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
2937 + else:
2938 + save_path_tmp = video_path[:-4] + "_tmp.mp4"
2939 + cache_video( tensor=sample[None], save_file=save_path_tmp, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
2940 + final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ]
2941 + import subprocess
2942 + subprocess.run(final_command, check=True)
2943 + os.remove(save_path_tmp)
2944 +
2945 + end_time = time.time()
2946
2947   inputs = get_function_arguments(generate_video, locals())
2948   inputs.pop("send_cmd")
2949 + inputs.pop("task_id")
2950   configs = prepare_inputs_dict("metadata", inputs)
2951 + configs["generation_time"] = round(end_time-start_time)
2952   metadata_choice = server_config.get("metadata_type","metadata")
2953   if metadata_choice == "json":
2954   with open(video_path.replace('.mp4', '.json'), 'w') as f:
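When a voice track was used, the silent render goes to a temporary file and ffmpeg then muxes the audio back in, trimming to the shorter of the two streams. The subprocess call above amounts to roughly the following command line (file names are illustrative):

```python
# Illustrative: the mux step above is equivalent to this ffmpeg invocation.
import shlex

save_path_tmp, audio_guide, video_path = "out_tmp.mp4", "voice.wav", "out.mp4"  # example names
final_command = ["ffmpeg", "-y", "-i", save_path_tmp, "-i", audio_guide,
                 "-c:v", "libx264", "-c:a", "aac", "-shortest",
                 "-loglevel", "warning", "-nostats", video_path]
print(shlex.join(final_command))
# ffmpeg -y -i out_tmp.mp4 -i voice.wav -c:v libx264 -c:a aac -shortest -loglevel warning -nostats out.mp4
# subprocess.run(final_command, check=True) is the call actually made in the diff
```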
3175   prompt_no = gen["prompt_no"]
3176   prompts_max = gen.get("prompts_max",0)
3177   total_generation = gen.get("total_generation", 1)
3178 + repeat_no = gen.get("repeat_no",0)
3179   total_generation += gen.get("extra_orders", 0)
3180   total_windows = gen.get("total_windows", 0)
3181   total_windows += gen.get("extra_windows", 0)

3518
3519   if target == "state":
3520   return inputs
3521 + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask", "audio_guide", "embedded_guidance_scale"]
3522   for k in unsaved_params:
3523   inputs.pop(k)
3524

3546   inputs.pop(k)
3547
3548   if not "Vace" in model_filename or "diffusion_forcing" in model_filename:
3549 + unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
3550   for k in unsaved_params:
3551   inputs.pop(k)
3552
3553 + if not "fantasy" in model_filename:
3554 + inputs.pop("audio_guidance_scale")
3555 +
3556 +
3557   if target == "metadata":
3558   inputs = {k: v for k,v in inputs.items() if v != None }
3559

3577   seed,
3578   num_inference_steps,
3579   guidance_scale,
3580 + audio_guidance_scale,
3581   flow_shift,
3582   embedded_guidance_scale,
3583   repeat_generation,

3597   video_guide,
3598   keep_frames_video_guide,
3599   video_mask,
3600 + audio_guide,
3601   sliding_window_size,
3602   sliding_window_overlap,
3603 + sliding_window_overlap_noise,
3604   sliding_window_discard_last_frames,
3605   remove_background_image_ref,
3606   temporal_upsampling,

3903   recammaster = "recam" in model_filename
3904   vace = "Vace" in model_filename
3905   phantom = "phantom" in model_filename
3906 + fantasy = "fantasy" in model_filename
3907   with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
3908   if diffusion_forcing:
3909   image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
4009
4010
4011   video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
4012 + audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy )
4013
4014   advanced_prompt = advanced_ui
4015   prompt_vars=[]

4042   wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
4043   wizard_variables_var = gr.Text(wizard_variables, visible = False)
4044   with gr.Row():
4045 + if test_class_i2v(model_filename) and False:
4046   resolution = gr.Dropdown(
4047   choices=[
4048   # 720p
4049 + ("720p (same amount of pixels)", "1280x720"),
4050 + ("480p (same amount of pixels)", "832x480"),
4051   ],
4052   value=ui_defaults.get("resolution","480p"),
4053   label="Resolution (video will have the same height / width ratio than the original image)"

4059   ("1280x720 (16:9, 720p)", "1280x720"),
4060   ("720x1280 (9:16, 720p)", "720x1280"),
4061   ("1024x1024 (4:3, 720p)", "1024x024"),
4062 + ("832x1104 (3:4, 720p)", "832x1104"),
4063 + ("1104x832 (3:4, 720p)", "1104x832"),
4064 + ("960x960 (1:1, 720p)", "960x960"),
4065   # 480p
4066   ("960x544 (16:9, 540p)", "960x544"),
4067   ("544x960 (16:9, 540p)", "544x960"),
4068   ("832x480 (16:9, 480p)", "832x480"),
4069   ("480x832 (9:16, 480p)", "480x832"),
4070 + ("832x624 (4:3, 480p)", "832x624"),
4071 + ("624x832 (3:4, 480p)", "624x832"),
4072 + ("720x720 (1:1, 480p)", "720x720"),
4073 + ("512x512 (1:1, 480p)", "512x512"),
4074   ],
4075   value=ui_defaults.get("resolution","832x480"),
4076 + label="Max Resolution (as it maybe less depending on video width / height ratio)" if test_class_i2v(model_filename) else "Resolution"
4077   )
4078   with gr.Row():
4079   if recammaster:

4082   video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
4083   elif vace:
4084   video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4085 + elif fantasy:
4086 + video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True)
4087   else:
4088   video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
4089   with gr.Row():

4103   choices=[
4104   ("Generate every combination of images and texts", 0),
4105   ("Match images and text prompts", 1),
4106 + ], visible= test_class_i2v(model_filename), label= "Multiple Images as Texts Prompts"
4107   )
4108   with gr.Row():
4109   guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
4110 + audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale",5), step=0.5, label="Audio Guidance", visible=fantasy)
4111   embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
4112   flow_shift = gr.Slider(0.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale")
4113   with gr.Row():

4174
4175   with gr.Tab("Quality"):
4176   with gr.Row():
4177 + gr.Markdown("<B>Skip Layer Guidance (improves video quality)</B>")
4178   with gr.Row():
4179   slg_switch = gr.Dropdown(
4180   choices=[

4223   if diffusion_forcing:
4224   sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
4225   sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4226 + sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
4227 + sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
4228   else:
4229   sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4230 + sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4231 + sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect")
4232 + sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
4233
4234
4235   with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):

4244   label="RIFLEx positional embedding to generate long video"
4245   )
4246
4247 + with gr.Row():
4248 + save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
4249
4250   if not update_form:
4251   with gr.Column():

5112   theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
5113
5114   with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
5115 + gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.5 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
5116   global model_list
5117
5118   tab_state = gr.State({ "tab_no":0 })

5153
5154   if __name__ == "__main__":
5155   atexit.register(autosave_queue)
5156 + download_ffmpeg()
5157   # threading.Thread(target=runner, daemon=True).start()
5158   os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
5159   server_port = int(args.server_port)