Commit d589426 by DeepBeepMeep
Parent(s): 6d57bc7

Various Memory Optimisations
Files changed:
- fantasytalking/infer.py  +13 -4
- requirements.txt         +1 -1
- wan/image2video.py       +2 -4
- wan/modules/model.py     +2 -2
- wgp.py                   +28 -25

fantasytalking/infer.py CHANGED
@@ -4,24 +4,33 @@ from transformers import Wav2Vec2Model, Wav2Vec2Processor
 
 from .model import FantasyTalkingAudioConditionModel
 from .utils import get_audio_features
-
+import gc, torch
 
 def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
     fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
     from mmgp import offload
     from accelerate import init_empty_weights
     from fantasytalking.model import AudioProjModel
+
+    torch.set_grad_enabled(False)
+
     with init_empty_weights():
         proj_model = AudioProjModel( 768, 2048)
     offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
-    proj_model.to(
+    proj_model.to("cpu").eval().requires_grad_(False)
 
     wav2vec_model_dir = "ckpts/wav2vec"
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
-    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).
+    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
+    wav2vec.to(device)
+    proj_model.to(device)
     audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
 
     audio_proj_fea = proj_model(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
-    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    wav2vec, proj_model= None, None
+    gc.collect()
+    torch.cuda.empty_cache()
+
    return audio_proj_split, audio_context_lens
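
The memory pattern these infer.py changes follow is: run everything under torch.set_grad_enabled(False), keep the audio models frozen and off the GPU until the single forward pass that needs them, then drop every reference and empty the CUDA cache. A minimal sketch of that pattern, assuming a stand-in nn.Linear in place of the real wav2vec and projection models:

import gc
import torch
from torch import nn

def extract_features(audio: torch.Tensor) -> torch.Tensor:
    """Load a model, run it once for inference, then free the memory it used."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_grad_enabled(False)                 # inference only: no autograd graph is kept

    model = nn.Linear(768, 2048)                  # stand-in for the real audio encoder
    model.to("cpu").eval().requires_grad_(False)  # frozen weights, loaded off the GPU

    model.to(device)                              # occupy VRAM only for the forward pass
    features = model(audio.to(device)).cpu()      # keep the result on the CPU side

    model = None                                  # drop the last reference so Python can reclaim it
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()                  # hand cached VRAM blocks back to the driver
    return features

Called as extract_features(torch.randn(1, 50, 768)), the GPU holds the stand-in weights only for the duration of the forward pass.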

requirements.txt CHANGED
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.4.
+mmgp==3.4.3
 peft==0.14.0
 mutagen
 pydantic==2.10.6
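
wgp.py pins the same release at runtime (target_mmgp_version) that requirements.txt pins at install time. A small sketch of such a runtime check, assuming the message wording shown here rather than the exact text the app prints:

from importlib.metadata import version

TARGET_MMGP_VERSION = "3.4.3"   # keep in sync with the pin in requirements.txt

installed = version("mmgp")     # raises PackageNotFoundError if mmgp is absent
if installed != TARGET_MMGP_VERSION:
    print(f"mmgp {installed} is installed but {TARGET_MMGP_VERSION} is expected; "
          "run 'pip install -r requirements.txt' to update.")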

wan/image2video.py CHANGED
@@ -103,7 +103,7 @@ class WanI2V:
         # dtype = torch.float16
         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
         self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
-
+        offload.change_dtype(self.model, dtype, True)
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
@@ -403,9 +403,7 @@ class WanI2V:
            if callback is not None:
                callback(i, latent, False)
 
-        x0 = [latent]
-
-        # x0 = [lat_y]
+        x0 = [latent]
        video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
        if any_end_frame and add_frames_for_end_image:
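
offload.change_dtype comes from the mmgp package, and its exact behaviour is not shown in this diff. The idea it serves here, converting the transformer weights to the working dtype up front rather than keeping a wider copy resident, can be sketched in plain PyTorch; the helper name and the choice of layers kept in float32 below are assumptions for illustration:

import torch
from torch import nn

def cast_weights(model: nn.Module, dtype: torch.dtype,
                 keep_fp32: tuple = (nn.LayerNorm,)) -> nn.Module:
    """Cast a model to a low-precision dtype, keeping norm layers in float32."""
    model.to(dtype)                       # fp32 -> bf16/fp16 halves the weight memory
    for module in model.modules():
        if isinstance(module, keep_fp32):
            module.to(torch.float32)      # mixed precision: sensitive layers stay wide
    return model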

wan/modules/model.py CHANGED
@@ -312,8 +312,6 @@ class WanI2VCrossAttention(WanSelfAttention):
         del x
         self.norm_q(q)
         q= q.view(b, -1, n, d)
-        if audio_scale != None:
-            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k = self.k(context)
         self.norm_k(k)
         k = k.view(b, -1, n, d)
@@ -323,6 +321,8 @@ class WanI2VCrossAttention(WanSelfAttention):
         del k,v
         x = pay_attention(qkv_list)
 
+        if audio_scale != None:
+            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k_img = self.k_img(context_img)
         self.norm_k_img(k_img)
         k_img = k_img.view(b, -1, n, d)
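
The only change here is ordering: the audio projection now runs after pay_attention, so its output is not held in VRAM while the large text cross-attention buffers are alive. A toy sketch of why ordering affects peak memory, with placeholder arguments standing in for the real q/k/v and audio tensors:

import torch
import torch.nn.functional as F

def cross_attention_then_audio(q, k, v, audio_k, audio_v, scale=1.0):
    # Main text cross-attention first; its large intermediate buffers can be
    # released as soon as `x` exists.
    x = F.scaled_dot_product_attention(q, k, v)
    # Audio branch afterwards: its activations never coexist with the buffers
    # above, so the peak VRAM of this block stays lower than if it ran first.
    audio_x = F.scaled_dot_product_attention(q, audio_k, audio_v)
    return x + scale * audio_x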

wgp.py CHANGED
@@ -40,7 +40,7 @@ global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.4.
+target_mmgp_version = "3.4.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
@@ -50,6 +50,7 @@ lock = threading.Lock()
 current_task_id = None
 task_id = 0
 
+
 def download_ffmpeg():
     if os.name != 'nt': return
     exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
@@ -1421,6 +1422,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
     "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
     ]:
     if Path(os.path.join("ckpts" , path)).is_file():
+        print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
         os.remove( os.path.join("ckpts" , path))
 
 
@@ -1511,14 +1513,21 @@ def get_model_filename(model_type, quantization):
     quantization = "bf16"
 
     if len(choices) <= 1:
-
-
-        sub_choices = [ name for name in choices if quantization in name]
-        if len(sub_choices) > 0:
-            return sub_choices[0]
+        raw_filename = choices[0]
     else:
-
+        sub_choices = [ name for name in choices if quantization in name]
+        if len(sub_choices) > 0:
+            raw_filename = sub_choices[0]
+        else:
+            raw_filename = choices[0]
 
+    if transformer_dtype == torch.float16 :
+        if "quanto_int8" in raw_filename:
+            raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
+        elif "quanto_mbf16_int8":
+            raw_filename= raw_filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
+    return raw_filename
+
 def get_settings_file_name(model_filename):
     return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")
 
@@ -1599,6 +1608,13 @@ def get_default_settings(filename):
     ui_defaults["num_inference_steps"] = default_number_steps
     return ui_defaults
 
+major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if major < 8:
+    print("Switching to f16 models as GPU architecture doesn't support bf16")
+    transformer_dtype = torch.float16
+else:
+    transformer_dtype = torch.float16 if args.fp16 else torch.bfloat16
+
 transformer_types = server_config.get("transformer_types", [])
 transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
 transformer_quantization =server_config.get("transformer_quantization", "int8")
@@ -1892,32 +1908,17 @@ def load_models(model_filename):
     global transformer_filename
 
     perc_reserved_mem_max = args.perc_reserved_mem_max
-
-    major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-    if major < 8:
-        print("Switching to f16 model as GPU architecture doesn't support bf16")
-        default_dtype = torch.float16
-    else:
-        default_dtype = torch.float16 if args.fp16 else torch.bfloat16
     model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
-    updated_model_filename = []
     for filename in model_filelist:
-        if default_dtype == torch.float16 :
-            if "quanto_int8" in filename:
-                filename = filename.replace("quanto_int8", "quanto_fp16_int8")
-            elif "quanto_mbf16_int8":
-                filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
-        updated_model_filename.append(filename)
         download_models(filename, text_encoder_filename)
-    model_filelist = updated_model_filename
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
     new_transformer_filename = model_filelist[-1]
     if test_class_i2v(new_transformer_filename):
-        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype =
+        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
-        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype =
+        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     wan_model._model_file_name = new_transformer_filename
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:
@@ -1926,7 +1927,7 @@ def load_models(model_filename):
         # kwargs["partialPinning"] = True
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo =
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
@@ -2410,6 +2411,7 @@ def generate_video(
 ):
     global wan_model, offloadobj, reload_needed
     gen = get_gen_info(state)
+    torch.set_grad_enabled(False)
 
     file_list = gen["file_list"]
     prompt_no = gen["prompt_no"]
@@ -2574,6 +2576,7 @@ def generate_video(
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)
 
+    torch.set_grad_enabled(False)
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
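
The dtype selection that moved to module level keys off the CUDA compute capability: bf16 needs an Ampere-class GPU (compute capability 8.0 or higher), so older cards fall back to fp16. A standalone sketch of the same decision, with a force_fp16 flag standing in for the --fp16 command-line option:

import torch

def pick_transformer_dtype(force_fp16: bool = False) -> torch.dtype:
    """bf16 requires compute capability >= 8.0 (Ampere); otherwise use fp16."""
    if not torch.cuda.is_available():
        return torch.float32                      # CPU fallback for completeness
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        print("Switching to fp16: this GPU architecture does not support bf16")
        return torch.float16
    return torch.float16 if force_fp16 else torch.bfloat16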