DeepBeepMeep committed on
Commit
03085c8
·
1 Parent(s): e420cd0

optimization for i2v with CausVid

Browse files
hyvideo/modules/models.py CHANGED
@@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
492
  return img, txt
493
 
494
  class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
495
- @staticmethod
496
- def preprocess_loras(model_filename, sd):
497
  if not "i2v" in model_filename:
498
  return sd
499
  new_sd = {}
 
492
  return img, txt
493
 
494
  class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
495
+ def preprocess_loras(self, model_filename, sd):
 
496
  if not "i2v" in model_filename:
497
  return sd
498
  new_sd = {}
wan/image2video.py CHANGED
@@ -330,8 +330,11 @@ class WanI2V:
330
  'current_step' :i,
331
  })
332
 
333
-
334
- if joint_pass:
 
 
 
335
  if audio_proj == None:
336
  noise_pred_cond, noise_pred_uncond = self.model(
337
  [latent_model_input, latent_model_input],
@@ -347,13 +350,7 @@ class WanI2V:
347
  if self._interrupt:
348
  return None
349
  else:
350
- noise_pred_cond = self.model(
351
- [latent_model_input],
352
- context=[context],
353
- audio_scale = None if audio_scale == None else [audio_scale],
354
- x_id=0,
355
- **kwargs,
356
- )[0]
357
  if self._interrupt:
358
  return None
359
 
@@ -377,22 +374,24 @@ class WanI2V:
377
  return None
378
  del latent_model_input
379
 
380
- # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
381
- if cfg_star_switch:
382
- positive_flat = noise_pred_cond.view(batch_size, -1)
383
- negative_flat = noise_pred_uncond.view(batch_size, -1)
384
-
385
- alpha = optimized_scale(positive_flat,negative_flat)
386
- alpha = alpha.view(batch_size, 1, 1, 1)
387
-
388
- if (i <= cfg_zero_step):
389
- noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
 
 
 
 
 
390
  else:
391
- noise_pred_uncond *= alpha
392
- if audio_scale == None:
393
- noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
394
- else:
395
- noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
396
  noise_pred_uncond, noise_pred_noaudio = None, None
397
  temp_x0 = sample_scheduler.step(
398
  noise_pred.unsqueeze(0),
 
330
  'current_step' :i,
331
  })
332
 
333
+ if guide_scale == 1:
334
+ noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
335
+ if self._interrupt:
336
+ return None
337
+ elif joint_pass:
338
  if audio_proj == None:
339
  noise_pred_cond, noise_pred_uncond = self.model(
340
  [latent_model_input, latent_model_input],
 
350
  if self._interrupt:
351
  return None
352
  else:
353
+ noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
 
 
 
 
 
 
354
  if self._interrupt:
355
  return None
356
 
 
374
  return None
375
  del latent_model_input
376
 
377
+ if guide_scale > 1:
378
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
379
+ if cfg_star_switch:
380
+ positive_flat = noise_pred_cond.view(batch_size, -1)
381
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
382
+
383
+ alpha = optimized_scale(positive_flat,negative_flat)
384
+ alpha = alpha.view(batch_size, 1, 1, 1)
385
+
386
+ if (i <= cfg_zero_step):
387
+ noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
388
+ else:
389
+ noise_pred_uncond *= alpha
390
+ if audio_scale == None:
391
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
392
  else:
393
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
394
+
 
 
 
395
  noise_pred_uncond, noise_pred_noaudio = None, None
396
  temp_x0 = sample_scheduler.step(
397
  noise_pred.unsqueeze(0),
wan/modules/model.py CHANGED
@@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
589
 
590
 
591
  class WanModel(ModelMixin, ConfigMixin):
592
- @staticmethod
593
- def preprocess_loras(model_filename, sd):
594
 
595
  first = next(iter(sd), None)
596
  if first == None:
@@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
634
  print(f"Lora alpha'{alpha_key}' is missing")
635
  new_sd.update(new_alphas)
636
  sd = new_sd
637
-
638
- if "text2video" in model_filename:
639
  new_sd = {}
640
  # convert loras for i2v to t2v
641
  for k,v in sd.items():
 
589
 
590
 
591
  class WanModel(ModelMixin, ConfigMixin):
592
+ def preprocess_loras(self, model_filename, sd):
 
593
 
594
  first = next(iter(sd), None)
595
  if first == None:
 
633
  print(f"Lora alpha'{alpha_key}' is missing")
634
  new_sd.update(new_alphas)
635
  sd = new_sd
636
+ from wgp import test_class_i2v
637
+ if not test_class_i2v(model_filename):
638
  new_sd = {}
639
  # convert loras for i2v to t2v
640
  for k,v in sd.items():