Respair committed on
Commit
19b3a06
·
verified ·
1 Parent(s): 21534d4

Update pkanade_24_multi_gpu_train_finetune_accelerate.py

Browse files
pkanade_24_multi_gpu_train_finetune_accelerate.py CHANGED
@@ -1,5 +1,7 @@
1
  # load packages
2
 
 
 
3
  import random
4
  import yaml
5
  import time
@@ -348,9 +350,9 @@ def main(config_path):
348
  for bib in range(len(mel_input_length)):
349
  mel_length = int(mel_input_length[bib].item())
350
  mel = mels[bib, :, :mel_input_length[bib]]
351
- s = model.predictor_encoder(mel.unsqueeze(0))
352
  ss.append(s)
353
- s = model.style_encoder(mel.unsqueeze(0))
354
  gs.append(s)
355
 
356
  s_dur = torch.stack(ss).squeeze() # global prosodic styles
@@ -432,8 +434,8 @@ def main(config_path):
432
  # if gt.size(-1) < 80:
433
  # continue
434
 
435
- s = model.style_encoder(gt)
436
- s_dur = model.predictor_encoder(gt)
437
 
438
  with torch.no_grad():
439
  F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
@@ -664,9 +666,9 @@ def main(config_path):
664
  for bib in range(len(mel_input_length)):
665
  mel_length = int(mel_input_length[bib].item())
666
  mel = mels[bib, :, :mel_input_length[bib]]
667
- s = model.predictor_encoder(mel.unsqueeze(0))
668
  ss.append(s)
669
- s = model.style_encoder(mel.unsqueeze(0))
670
  gs.append(s)
671
 
672
  s = torch.stack(ss).squeeze()
@@ -708,7 +710,7 @@ def main(config_path):
708
  en = torch.stack(en)
709
  p_en = torch.stack(p_en)
710
  gt = torch.stack(gt).detach()
711
- s = model.predictor_encoder(gt)
712
 
713
  F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
714
 
 
1
  # load packages
2
 
3
+
4
+
5
  import random
6
  import yaml
7
  import time
 
350
  for bib in range(len(mel_input_length)):
351
  mel_length = int(mel_input_length[bib].item())
352
  mel = mels[bib, :, :mel_input_length[bib]]
353
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
354
  ss.append(s)
355
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
356
  gs.append(s)
357
 
358
  s_dur = torch.stack(ss).squeeze() # global prosodic styles
 
434
  # if gt.size(-1) < 80:
435
  # continue
436
 
437
+ s = model.style_encoder(gt.unsqueeze(0))
438
+ s_dur = model.predictor_encoder(gt.unsqueeze(0))
439
 
440
  with torch.no_grad():
441
  F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
 
666
  for bib in range(len(mel_input_length)):
667
  mel_length = int(mel_input_length[bib].item())
668
  mel = mels[bib, :, :mel_input_length[bib]]
669
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
670
  ss.append(s)
671
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
672
  gs.append(s)
673
 
674
  s = torch.stack(ss).squeeze()
 
710
  en = torch.stack(en)
711
  p_en = torch.stack(p_en)
712
  gt = torch.stack(gt).detach()
713
+ s = model.predictor_encoder(gt.unsqueeze(0))
714
 
715
  F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
716