DeepBeepMeep committed
Commit d589426 · Parent: 6d57bc7

Various Memory Optimisations

fantasytalking/infer.py CHANGED
@@ -4,24 +4,33 @@ from transformers import Wav2Vec2Model, Wav2Vec2Processor
 
 from .model import FantasyTalkingAudioConditionModel
 from .utils import get_audio_features
-
+import gc, torch
 
 def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
     fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
     from mmgp import offload
     from accelerate import init_empty_weights
     from fantasytalking.model import AudioProjModel
+
+    torch.set_grad_enabled(False)
+
     with init_empty_weights():
         proj_model = AudioProjModel( 768, 2048)
     offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
-    proj_model.to(device).eval().requires_grad_(False)
+    proj_model.to("cpu").eval().requires_grad_(False)
 
     wav2vec_model_dir = "ckpts/wav2vec"
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
-    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
+    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
+    wav2vec.to(device)
+    proj_model.to(device)
     audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
 
     audio_proj_fea = proj_model(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
-    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
+    wav2vec, proj_model= None, None
+    gc.collect()
+    torch.cuda.empty_cache()
+
     return audio_proj_split, audio_context_lens
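
The reworked parse_audio runs with autograd disabled, stages the Wav2Vec2 and projection weights on the CPU, moves them to the GPU only for the forward pass, and then drops both models and empties the CUDA cache once the audio features have been extracted. A minimal sketch of that pattern, using a stand-in nn.Linear rather than the project's actual models:

import gc

import torch
import torch.nn as nn

def run_encoder_once(features: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_grad_enabled(False)              # inference only: no autograd graph is built
    encoder = nn.Linear(768, 2048).eval()      # stand-in for the real audio projection model
    encoder.requires_grad_(False)              # weights stay frozen
    encoder.to(device)                         # on the GPU only for the duration of this call
    out = encoder(features.to(device)).cpu()   # keep only the (small) result, on the CPU
    encoder = None                             # drop the last reference to the weights
    gc.collect()                               # let Python reclaim the module
    if torch.cuda.is_available():
        torch.cuda.empty_cache()               # return cached VRAM blocks to the allocator
    return out

audio_features = torch.randn(1, 120, 768)
print(run_encoder_once(audio_features).shape)  # torch.Size([1, 120, 2048])
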
requirements.txt CHANGED
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.4.2
+mmgp==3.4.3
 peft==0.14.0
 mutagen
 pydantic==2.10.6
wan/image2video.py CHANGED
@@ -103,7 +103,7 @@ class WanI2V:
         # dtype = torch.float16
         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
         self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
-        # offload.change_dtype(self.model, dtype, True)
+        offload.change_dtype(self.model, dtype, True)
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
@@ -403,9 +403,7 @@ class WanI2V:
             if callback is not None:
                 callback(i, latent, False)
 
-        x0 = [latent]
-
-        # x0 = [lat_y]
+        x0 = [latent]
         video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
         if any_end_frame and add_frames_for_end_image:
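
Re-enabling offload.change_dtype(self.model, dtype, True) converts the transformer's weights to the runtime dtype right after loading instead of leaving them in their on-disk precision. change_dtype is an mmgp helper and the meaning of its third argument is not shown here; assuming it essentially casts the floating-point parameters to the target dtype, a plain-PyTorch sketch of the saving looks like this:

import torch
import torch.nn as nn

def cast_float_weights(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    # nn.Module.to(dtype=...) casts floating-point parameters and buffers only,
    # so integer buffers are left untouched.
    return model.to(dtype=dtype)

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
before = sum(p.numel() * p.element_size() for p in model.parameters())
cast_float_weights(model, torch.bfloat16)
after = sum(p.numel() * p.element_size() for p in model.parameters())
print(before, "->", after)   # ~8.4 MB of fp32 weights become ~4.2 MB of bf16
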
wan/modules/model.py CHANGED
@@ -312,8 +312,6 @@ class WanI2VCrossAttention(WanSelfAttention):
         del x
         self.norm_q(q)
         q= q.view(b, -1, n, d)
-        if audio_scale != None:
-            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k = self.k(context)
         self.norm_k(k)
         k = k.view(b, -1, n, d)
@@ -323,6 +321,8 @@ class WanI2VCrossAttention(WanSelfAttention):
         del k,v
         x = pay_attention(qkv_list)
 
+        if audio_scale != None:
+            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k_img = self.k_img(context_img)
         self.norm_k_img(k_img)
         k_img = k_img.view(b, -1, n, d)
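
The audio_scale branch now runs after pay_attention, presumably so that the audio cross-attention output is not held in VRAM while the main cross-attention and its temporaries are computed; q is still alive at the new call site, so the result should be unchanged. The effect of deferring a side branch until after a memory-hungry op can be seen in a small self-contained experiment (illustrative sizes, not the model's):

import torch

def peak_mb(fn) -> float:
    # High-water mark of GPU allocations for one call, in MiB.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

def side_branch_first():
    q = torch.randn(8192, 4096, device="cuda")   # 128 MiB
    side = q * 2                                 # 128 MiB, held across the big op below
    main = torch.relu(q @ q.T)                   # 256 MiB product plus 256 MiB result
    return main.sum() + side.sum()

def side_branch_last():
    q = torch.randn(8192, 4096, device="cuda")
    main = torch.relu(q @ q.T)                   # big op runs with nothing extra alive
    side = q * 2                                 # side branch allocated afterwards
    return main.sum() + side.sum()

if torch.cuda.is_available():
    print(f"side branch before the big op: {peak_mb(side_branch_first):.0f} MiB peak")
    print(f"side branch after the big op:  {peak_mb(side_branch_last):.0f} MiB peak")
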
wgp.py CHANGED
@@ -40,7 +40,7 @@ global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.4.2"
+target_mmgp_version = "3.4.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
@@ -50,6 +50,7 @@ lock = threading.Lock()
 current_task_id = None
 task_id = 0
 
+
 def download_ffmpeg():
     if os.name != 'nt': return
     exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
@@ -1421,6 +1422,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
     "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
     ]:
     if Path(os.path.join("ckpts" , path)).is_file():
+        print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
         os.remove( os.path.join("ckpts" , path))
 
 
@@ -1511,14 +1513,21 @@ def get_model_filename(model_type, quantization):
         quantization = "bf16"
 
     if len(choices) <= 1:
-        return choices[0]
-
-    sub_choices = [ name for name in choices if quantization in name]
-    if len(sub_choices) > 0:
-        return sub_choices[0]
+        raw_filename = choices[0]
     else:
-        return choices[0]
+        sub_choices = [ name for name in choices if quantization in name]
+        if len(sub_choices) > 0:
+            raw_filename = sub_choices[0]
+        else:
+            raw_filename = choices[0]
 
+    if transformer_dtype == torch.float16 :
+        if "quanto_int8" in raw_filename:
+            raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
+        elif "quanto_mbf16_int8":
+            raw_filename= raw_filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
+    return raw_filename
+
 def get_settings_file_name(model_filename):
     return os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")
 
@@ -1599,6 +1608,13 @@ def get_default_settings(filename):
         ui_defaults["num_inference_steps"] = default_number_steps
     return ui_defaults
 
+major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if major < 8:
+    print("Switching to f16 models as GPU architecture doesn't support bf16")
+    transformer_dtype = torch.float16
+else:
+    transformer_dtype = torch.float16 if args.fp16 else torch.bfloat16
+
 transformer_types = server_config.get("transformer_types", [])
 transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
 transformer_quantization =server_config.get("transformer_quantization", "int8")
@@ -1892,32 +1908,17 @@ def load_models(model_filename):
     global transformer_filename
 
     perc_reserved_mem_max = args.perc_reserved_mem_max
-
-    major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-    if major < 8:
-        print("Switching to f16 model as GPU architecture doesn't support bf16")
-        default_dtype = torch.float16
-    else:
-        default_dtype = torch.float16 if args.fp16 else torch.bfloat16
     model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
-    updated_model_filename = []
     for filename in model_filelist:
-        if default_dtype == torch.float16 :
-            if "quanto_int8" in filename:
-                filename = filename.replace("quanto_int8", "quanto_fp16_int8")
-            elif "quanto_mbf16_int8":
-                filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
-        updated_model_filename.append(filename)
         download_models(filename, text_encoder_filename)
-    model_filelist = updated_model_filename
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
     new_transformer_filename = model_filelist[-1]
     if test_class_i2v(new_transformer_filename):
-        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
-        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     wan_model._model_file_name = new_transformer_filename
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:
@@ -1926,7 +1927,7 @@ def load_models(model_filename):
         # kwargs["partialPinning"] = True
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
@@ -2410,6 +2411,7 @@ def generate_video(
 ):
     global wan_model, offloadobj, reload_needed
     gen = get_gen_info(state)
+    torch.set_grad_enabled(False)
 
     file_list = gen["file_list"]
     prompt_no = gen["prompt_no"]
@@ -2574,6 +2576,7 @@ def generate_video(
     if seed == None or seed <0:
        seed = random.randint(0, 999999999)
 
+    torch.set_grad_enabled(False)
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
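
In wgp.py the fp16/bf16 decision (GPU compute capability plus --fp16) now happens once at startup, and get_model_filename rewrites quantized checkpoint names to their fp16 variants, replacing the per-file remapping that load_models used to do. A pure-function sketch of that remapping follows; note that the committed elif "quanto_mbf16_int8": tests a non-empty string literal and is therefore always true, so the membership test assumed below is presumably what is intended:

import torch

def pick_fp16_variant(filename: str, transformer_dtype: torch.dtype) -> str:
    # When the GPU (or --fp16) forces float16, point at the fp16 flavour of the
    # quantized checkpoint instead of the bf16 one.
    if transformer_dtype == torch.float16:
        if "quanto_int8" in filename:
            return filename.replace("quanto_int8", "quanto_fp16_int8")
        if "quanto_mbf16_int8" in filename:   # assumed membership test, see note above
            return filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
    return filename

assert pick_fp16_variant("wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",
                         torch.float16).endswith("quanto_mfp16_int8.safetensors")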