Respair committed
Commit e4d34e8 · verified · 1 Parent(s): 302641c

Upload Sana/finetune.py with huggingface_hub

Files changed (1): Sana/finetune.py (+796, -0)
Sana/finetune.py ADDED
@@ -0,0 +1,796 @@
# load packages

# Use this script if the DDP scripts didn't work out for you.

import os
import os.path as osp  # needed for the os.makedirs / osp.join calls below
import random
import yaml
import time
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import click
import shutil
import logging
import warnings
warnings.simplefilter('ignore')

from torch.utils.tensorboard import SummaryWriter

from meldataset import build_dataloader

from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet
from Utils.PLBERT.util import load_plbert

from models import *
from losses import *
from utils import *

from Modules.slmadv import SLMAdversarialLoss
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

from optimizers import build_optimizer

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import tqdm, ProjectConfiguration
from accelerate.logging import get_logger

try:
    import wandb
except ImportError:
    wandb = None

# from Utils.fsdp_patch import replace_fsdp_state_dict_type
# replace_fsdp_state_dict_type()

logger = get_logger(__name__, log_level="DEBUG")

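# Typical invocation, as a sketch (the config path shown is just the click
# default below; launching through `accelerate launch` is an assumption based
# on the Accelerator usage in main(), and plain `python` also works for a
# single process):
#
#   accelerate launch Sana/finetune.py --config_path Configs/config_ft.yml
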
# simple fix for DataParallel that allows access to class attributes
class MyDataParallel(torch.nn.DataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

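# Why the override above helps, as a sketch (`net` and `n_down` are
# hypothetical names for illustration): torch.nn.DataParallel(net).n_down
# raises AttributeError because the wrapper only proxies forward(), while
# MyDataParallel(net).n_down falls through to net.n_down via self.module.
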
@click.command()
@click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
def main(config_path):
    config = yaml.safe_load(open(config_path))

    save_iter = 1000

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs], mixed_precision='bf16')
    if accelerator.is_main_process:
        writer = SummaryWriter(log_dir + "/tensorboard")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 10)

    epochs = config.get('epochs', 200)
    save_freq = config.get('save_freq', 2)
    log_interval = config.get('log_interval', 10)

    data_params = config.get('data_params', None)
    sr = config['preprocess_params'].get('sr', 24000)
    train_path = data_params['train_data']
    val_path = data_params['val_data']
    root_path = data_params['root_path']
    min_length = data_params['min_length']
    OOD_data = data_params['OOD_data']

    max_len = config.get('max_len', 200)

    loss_params = Munch(config['loss_params'])
    diff_epoch = loss_params.diff_epoch
    joint_epoch = loss_params.joint_epoch

    optimizer_params = Munch(config['optimizer_params'])

    train_list, val_list = get_data_path_list(train_path, val_path)
    device = 'cuda'

    train_dataloader = build_dataloader(train_list,
                                        root_path,
                                        OOD_data=OOD_data,
                                        min_length=min_length,
                                        batch_size=batch_size,
                                        num_workers=16,
                                        dataset_config={},
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      root_path,
                                      OOD_data=OOD_data,
                                      min_length=min_length,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=4,
                                      device=device,
                                      dataset_config={})

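    # Sketch of the config keys this script reads (values are illustrative
    # placeholders, not recommendations; the real file is Configs/config_ft.yml):
    #
    #   log_dir: ...
    #   batch_size: 10
    #   epochs: 200
    #   save_freq: 2
    #   log_interval: 10
    #   max_len: 200
    #   preprocess_params: { sr: 24000 }
    #   data_params: { train_data: ..., val_data: ..., root_path: ..., min_length: ..., OOD_data: ... }
    #   loss_params, optimizer_params, model_params, slmadv_params: see their uses below
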
    with accelerator.main_process_first():
        # load pretrained ASR model
        ASR_config = config.get('ASR_config', False)
        ASR_path = config.get('ASR_path', False)
        text_aligner = load_ASR_models(ASR_path, ASR_config)

        # load pretrained F0 model
        F0_path = config.get('F0_path', False)
        pitch_extractor = load_F0_models(F0_path)

        # load PL-BERT model (load_plbert is imported at the top)
        BERT_path = config.get('PLBERT_dir', False)
        plbert = load_plbert(BERT_path)

    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }

    # build model
    model_params = recursive_munch(config['model_params'])
    multispeaker = model_params.multispeaker
    model = build_model(model_params, text_aligner, pitch_extractor, plbert)
    _ = [model[key].to(device) for key in model]

    scheduler_params_dict = {key: scheduler_params.copy() for key in model}
    scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
    scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
    scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2

    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict,
                                lr=float(config['optimizer_params'].get('lr', 1e-4)))

    for k, v in optimizer.optimizers.items():
        optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
        optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])

    for k in model:
        model[k] = accelerator.prepare(model[k])

    train_dataloader, val_dataloader = accelerator.prepare(
        train_dataloader, val_dataloader
    )

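    # The scheduler_params keys (max_lr, pct_start, epochs, steps_per_epoch)
    # mirror torch.optim.lr_scheduler.OneCycleLR's signature, so build_optimizer
    # presumably attaches a OneCycleLR-style schedule per module, with the
    # bert/decoder/style_encoder overrides above doubling their peak LR; that
    # is an assumption about optimizers.build_optimizer, not shown in this file.
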
    start_epoch = 0
    iters = 0

    with accelerator.main_process_first():
        if config.get('pretrained_model', '') and config.get('second_stage_load_pretrained', False):
            model, optimizer, start_epoch, iters = load_checkpoint(
                model,
                optimizer,
                config['pretrained_model'],
                load_only_params=config.get('load_only_params', True)
            )
            accelerator.print('Loading the checkpoint at %s ...' % config['pretrained_model'])
        elif config.get('first_stage_path', ''):
            first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
            accelerator.print('Loading the first stage model at %s ...' % first_stage_path)
            model, optimizer, start_epoch, iters = load_checkpoint(
                model,
                optimizer,
                first_stage_path,
                load_only_params=True,
                ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion'])  # keep starting epoch for tensorboard log
        else:
            raise ValueError('You need to specify a pretrained model or a first stage model path.')

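    # When loading only the first-stage checkpoint, the modules listed in
    # ignore_modules (PL-BERT, the duration/prosody predictor and its encoder,
    # the MPD/MSD and WavLM discriminators, and the style diffusion) keep
    # their fresh initialization and are trained from scratch below; only the
    # first-stage acoustic modules are restored.
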
    gl = GeneratorLoss(model.mpd, model.msd).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
    wl = WavLMLoss(model_params.slm.model,
                   model.wd,
                   sr,
                   model_params.slm.sr).to(device)

    wl = wl.eval()

    sampler = DiffusionSampler(
        model.diffusion.module.diffusion,
        sampler=ADPM2Sampler(),
        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
        clamp=False
    )
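
    # For reference, KarrasSchedule spaces its N noise levels as
    #   sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
    # (Karras et al., 2022); a larger rho concentrates steps near sigma_min.
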
    # adjust BERT learning rate
    for g in optimizer.optimizers['bert'].param_groups:
        g['betas'] = (0.9, 0.99)
        g['lr'] = optimizer_params.bert_lr
        g['initial_lr'] = optimizer_params.bert_lr
        g['min_lr'] = 0
        g['weight_decay'] = 0.01

    # adjust acoustic module learning rate
    for module in ["decoder", "style_encoder"]:
        for g in optimizer.optimizers[module].param_groups:
            g['betas'] = (0.0, 0.99)
            g['lr'] = optimizer_params.ft_lr
            g['initial_lr'] = optimizer_params.ft_lr
            g['min_lr'] = 0
            g['weight_decay'] = 1e-4

    # # load models if there is a model
    # if load_pretrained:
    #     model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
    #                                                            load_only_params=config.get('load_only_params', True))

    try:
        n_down = model.text_aligner.module.n_down
    except AttributeError:  # not wrapped by DDP/DataParallel
        n_down = model.text_aligner.n_down

    best_loss = float('inf')  # best test loss
    loss_train_record = []
    loss_test_record = []
    iters = 0

    criterion = nn.L1Loss()  # F0 loss (regression)
    torch.cuda.empty_cache()

    stft_loss = MultiResolutionSTFTLoss().to(device)

    print('BERT', optimizer.optimizers['bert'])
    print('decoder', optimizer.optimizers['decoder'])

    start_ds = False

    running_std = []

    slmadv_params = Munch(config['slmadv_params'])
    slmadv = SLMAdversarialLoss(model, wl, sampler,
                                slmadv_params.min_len,
                                slmadv_params.max_len,
                                batch_percentage=slmadv_params.batch_percentage,
                                skip_update=slmadv_params.iter,
                                sig=slmadv_params.sig
                                )

    for epoch in range(start_epoch, epochs):
        running_loss = 0
        start_time = time.time()

        _ = [model[key].eval() for key in model]

        model.text_aligner.train()
        model.text_encoder.train()

        model.predictor.train()
        model.bert_encoder.train()
        model.bert.train()
        model.msd.train()
        model.mpd.train()

        for i, batch in enumerate(train_dataloader):
            waves = batch[0]
            batch = [b.to(device) for b in batch[1:]]
            texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                mel_mask = length_to_mask(mel_input_length).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)

                # compute reference styles
                if multispeaker and epoch >= diff_epoch:
                    ref_ss = model.style_encoder(ref_mels)
                    ref_sp = model.predictor_encoder(ref_mels)
                    ref = torch.cat([ref_ss, ref_sp], dim=1)

            try:
                ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
                s2s_attn = s2s_attn.transpose(-1, -2)
                s2s_attn = s2s_attn[..., 1:]
                s2s_attn = s2s_attn.transpose(-1, -2)
            except Exception:  # skip batches the aligner cannot handle
                continue

            mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
            s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

            # encode
            t_en = model.text_encoder(texts, input_lengths, text_mask)

            # 50% chance of using the monotonic version
            if bool(random.getrandbits(1)):
                asr = (t_en @ s2s_attn)
            else:
                asr = (t_en @ s2s_attn_mono)

            d_gt = s2s_attn_mono.sum(axis=-1).detach()

            # compute the style of the entire utterance
            # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
            ss = []
            gs = []
            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item())
                mel = mels[bib, :, :mel_input_length[bib]]
                s = model.predictor_encoder(mel.unsqueeze(0))
                ss.append(s)
                s = model.style_encoder(mel.unsqueeze(0))
                gs.append(s)

            s_dur = torch.stack(ss).squeeze()  # global prosodic styles
            gs = torch.stack(gs).squeeze()  # global acoustic styles
            s_trg = torch.cat([gs, s_dur], dim=-1).detach()  # ground truth for denoiser

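            # Shape sketch (a reading aid, not executed code; exact sizes
            # depend on the model config): with texts [B, T_text] and mels
            # [B, n_mels, T_mel], s2s_attn is roughly
            # [B, T_text, T_mel // 2**n_down], so t_en @ s2s_attn expands the
            # text encodings to frame rate, and d_gt sums each row of the
            # monotonic alignment to get per-phoneme durations in frames.
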
            bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
            d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

            # denoiser training
            if epoch >= diff_epoch:
                num_steps = np.random.randint(3, 5)

                if model_params.diffusion.dist.estimate_sigma_data:
                    model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item()  # batch-wise std estimation
                    running_std.append(model.diffusion.module.diffusion.sigma_data)

                if multispeaker:
                    s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
                                      embedding=bert_dur,
                                      embedding_scale=1,
                                      features=ref,  # reference from the same speaker as the embedding
                                      embedding_mask_proba=0.1,
                                      num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean()  # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach())  # style reconstruction loss
                else:
                    s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
                                      embedding=bert_dur,
                                      embedding_scale=1,
                                      embedding_mask_proba=0.1,
                                      num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean()  # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach())  # style reconstruction loss
            else:
                loss_sty = 0
                loss_diff = 0

            s_loss = 0

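            # Note on sigma_data: when estimate_sigma_data is set, the EDM
            # sigma_data is re-estimated from each batch's style std, every
            # estimate is collected in running_std, and their mean is written
            # back into the config YAML at the end of training so inference
            # can reuse the estimate.
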
            d, p = model.predictor(d_en, s_dur,
                                   input_lengths,
                                   s2s_attn_mono,
                                   text_mask)

            mel_len_st = int(mel_input_length.min().item() / 2 - 1)
            mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
            en = []
            gt = []
            p_en = []
            wav = []
            st = []

            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item() / 2)

                random_start = np.random.randint(0, mel_length - mel_len)
                en.append(asr[bib, :, random_start:random_start+mel_len])
                p_en.append(p[bib, :, random_start:random_start+mel_len])
                gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                wav.append(torch.from_numpy(y).to(device))

                # style reference (better to be different from the GT)
                random_start = np.random.randint(0, mel_length - mel_len_st)
                st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])

            wav = torch.stack(wav).float().detach()

            en = torch.stack(en)
            p_en = torch.stack(p_en)
            gt = torch.stack(gt).detach()
            st = torch.stack(st).detach()

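            # Indexing sketch: asr/p live at half the mel frame rate, so a
            # clip of mel_len such frames spans mel_len * 2 mel frames, and
            # the `* 300` maps mel frames to waveform samples; this assumes a
            # mel hop size of 300 samples (12.5 ms at the default sr of 24000).
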
            # if gt.size(-1) < 80:
            #     continue

            s = model.style_encoder(gt)
            s_dur = model.predictor_encoder(gt)

            with torch.no_grad():
                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

                N_real = log_norm(gt.unsqueeze(1)).squeeze(1)

                y_rec_gt = wav.unsqueeze(1)
                y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)

                wav = y_rec_gt

            # F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
            F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)

            y_rec = model.decoder(en, F0_fake, N_fake, s)

            loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
            loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)

            optimizer.zero_grad()
            d_loss = dl(wav.detach(), y_rec.detach()).mean()
            d_loss.backward()
            optimizer.step('msd')
            optimizer.step('mpd')

            # generator loss
            optimizer.zero_grad()

            loss_mel = stft_loss(y_rec, wav)
            loss_gen_all = gl(wav, y_rec).mean()
            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()

            loss_ce = 0
            loss_dur = 0
            for _s2s_pred, _text_input, _text_length in zip(d, d_gt, input_lengths):
                _s2s_pred = _s2s_pred[:_text_length, :]
                _text_input = _text_input[:_text_length].long()
                _s2s_trg = torch.zeros_like(_s2s_pred)
                for idx in range(_s2s_trg.shape[0]):  # 'idx' avoids shadowing the prosody tensor p
                    _s2s_trg[idx, :_text_input[idx]] = 1
                _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)

                loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                      _text_input[1:_text_length-1])
                loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())

            loss_ce /= texts.size(0)
            loss_dur /= texts.size(0)

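            # Worked example of the duration target above (illustrative
            # numbers): for a phoneme with ground-truth duration 3 and a row
            # of width 5, _s2s_trg's row is [1, 1, 1, 0, 0]; the BCE term
            # pushes each logit toward that step pattern, while
            # sigmoid().sum() recovers a soft duration (~3) for the L1 term.
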
            loss_s2s = 0
            for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
                loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
            loss_s2s /= texts.size(0)

            loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10

            g_loss = loss_params.lambda_mel * loss_mel + \
                     loss_params.lambda_F0 * loss_F0_rec + \
                     loss_params.lambda_ce * loss_ce + \
                     loss_params.lambda_norm * loss_norm_rec + \
                     loss_params.lambda_dur * loss_dur + \
                     loss_params.lambda_gen * loss_gen_all + \
                     loss_params.lambda_slm * loss_lm + \
                     loss_params.lambda_sty * loss_sty + \
                     loss_params.lambda_diff * loss_diff + \
                     loss_params.lambda_mono * loss_mono + \
                     loss_params.lambda_s2s * loss_s2s

            running_loss += loss_mel.item()
            g_loss.backward()
            if torch.isnan(g_loss):
                from IPython.core.debugger import set_trace
                set_trace()

            optimizer.step('bert_encoder')
            optimizer.step('bert')
            optimizer.step('predictor')
            optimizer.step('predictor_encoder')
            optimizer.step('style_encoder')
            optimizer.step('decoder')

            optimizer.step('text_encoder')
            optimizer.step('text_aligner')

            if epoch >= diff_epoch:
                optimizer.step('diffusion')

            d_loss_slm, loss_gen_lm = 0, 0
            if epoch >= joint_epoch:
                # randomly pick whether to use in-distribution text
                use_ind = np.random.rand() < 0.5

                if use_ind:
                    ref_lengths = input_lengths
                    ref_texts = texts

                slm_out = slmadv(i,
                                 y_rec_gt,
                                 y_rec_gt_pred,
                                 waves,
                                 mel_input_length,
                                 ref_texts,
                                 ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)

                if slm_out is not None:
                    d_loss_slm, loss_gen_lm, y_pred = slm_out

                    # SLM generator loss
                    optimizer.zero_grad()
                    loss_gen_lm.backward()

                    # compute the gradient norm
                    total_norm = {}
                    for key in model.keys():
                        total_norm[key] = 0
                        parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
                        for p in parameters:
                            param_norm = p.grad.detach().data.norm(2)
                            total_norm[key] += param_norm.item() ** 2
                        total_norm[key] = total_norm[key] ** 0.5

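                    # total_norm[key] is the standard global L2 gradient norm,
                    # ||g|| = sqrt(sum_p ||g_p||_2^2), computed per module so
                    # the predictor's norm can gate the rescaling below.
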
                    # gradient scaling
                    if total_norm['predictor'] > slmadv_params.thresh:
                        for key in model.keys():
                            for p in model[key].parameters():
                                if p.grad is not None:
                                    p.grad *= (1 / total_norm['predictor'])

                    for p in model.predictor.duration_proj.parameters():
                        if p.grad is not None:
                            p.grad *= slmadv_params.scale

                    for p in model.predictor.lstm.parameters():
                        if p.grad is not None:
                            p.grad *= slmadv_params.scale

                    for p in model.diffusion.parameters():
                        if p.grad is not None:
                            p.grad *= slmadv_params.scale

                    optimizer.step('bert_encoder')
                    optimizer.step('bert')
                    optimizer.step('predictor')
                    optimizer.step('diffusion')

                    # SLM discriminator loss
                    if d_loss_slm != 0:
                        optimizer.zero_grad()
                        d_loss_slm.backward(retain_graph=True)
                        optimizer.step('wd')

            iters += 1

            if (i + 1) % log_interval == 0 and accelerator.is_main_process:
                log_print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
                          % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono), logger)

                writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
                writer.add_scalar('train/gen_loss', loss_gen_all, iters)
                writer.add_scalar('train/d_loss', d_loss, iters)
                writer.add_scalar('train/ce_loss', loss_ce, iters)
                writer.add_scalar('train/dur_loss', loss_dur, iters)
                writer.add_scalar('train/slm_loss', loss_lm, iters)
                writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
                writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
                writer.add_scalar('train/sty_loss', loss_sty, iters)
                writer.add_scalar('train/diff_loss', loss_diff, iters)
                writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
                writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)

                running_loss = 0

                print('Time elapsed:', time.time() - start_time)

            if (i + 1) % save_iter == 0 and accelerator.is_main_process:
                accelerator.print(f'Saving on step {epoch*len(train_dataloader)+i}...')
                state = {
                    'net': {key: model[key].state_dict() for key in model},
                    'optimizer': optimizer.state_dict(),
                    'iters': iters,
                    'epoch': epoch,
                }
                save_path = osp.join(log_dir, f'Sana_Finetune__{epoch*len(train_dataloader)+i}.pth')
                torch.save(state, save_path)

        loss_test = 0
        loss_align = 0
        loss_f = 0
        _ = [model[key].eval() for key in model]

        with torch.no_grad():
            iters_test = 0
            for batch_idx, batch in enumerate(val_dataloader):
                optimizer.zero_grad()

                try:
                    waves = batch[0]
                    batch = [b.to(device) for b in batch[1:]]
                    texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                    with torch.no_grad():
                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                        text_mask = length_to_mask(input_lengths).to(texts.device)

                        _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                        s2s_attn = s2s_attn.transpose(-1, -2)
                        s2s_attn = s2s_attn[..., 1:]
                        s2s_attn = s2s_attn.transpose(-1, -2)

                        mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                        s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                        # encode
                        t_en = model.text_encoder(texts, input_lengths, text_mask)
                        asr = (t_en @ s2s_attn_mono)

                        d_gt = s2s_attn_mono.sum(axis=-1).detach()

                    ss = []
                    gs = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item())
                        mel = mels[bib, :, :mel_input_length[bib]]
                        s = model.predictor_encoder(mel.unsqueeze(0))
                        ss.append(s)
                        s = model.style_encoder(mel.unsqueeze(0))
                        gs.append(s)

                    s = torch.stack(ss).squeeze()
                    gs = torch.stack(gs).squeeze()
                    s_trg = torch.cat([s, gs], dim=-1).detach()

                    bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
                    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
                    d, p = model.predictor(d_en, s,
                                           input_lengths,
                                           s2s_attn_mono,
                                           text_mask)
                    # get clips
                    # mel_len = int(mel_input_length.min().item() / 2 - 1)

                    mel_input_length_all = accelerator.gather(mel_input_length)  # for balanced load
                    mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])

                    mel_len_st = int(mel_input_length.min().item() / 2 - 1)
                    en = []
                    gt = []

                    p_en = []
                    wav = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item() / 2)

                        random_start = np.random.randint(0, mel_length - mel_len)
                        en.append(asr[bib, :, random_start:random_start+mel_len])
                        p_en.append(p[bib, :, random_start:random_start+mel_len])

                        gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
                        y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                        wav.append(torch.from_numpy(y).to(device))

                    wav = torch.stack(wav).float().detach()

                    en = torch.stack(en)
                    p_en = torch.stack(p_en)
                    gt = torch.stack(gt).detach()
                    s = model.predictor_encoder(gt)

                    F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)

                    loss_dur = 0
                    for _s2s_pred, _text_input, _text_length in zip(d, d_gt, input_lengths):
                        _s2s_pred = _s2s_pred[:_text_length, :]
                        _text_input = _text_input[:_text_length].long()
                        _s2s_trg = torch.zeros_like(_s2s_pred)
                        for bib in range(_s2s_trg.shape[0]):
                            _s2s_trg[bib, :_text_input[bib]] = 1
                        _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
                        loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                              _text_input[1:_text_length-1])

                    loss_dur /= texts.size(0)

                    s = model.style_encoder(gt)

                    try:
                        y_rec = model.decoder(en, F0_fake, N_fake, s)
                    except Exception as e:
                        accelerator.print(str(e))

                    try:
                        loss_mel = stft_loss(y_rec.squeeze(), wav.detach())

                        F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))

                        loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
                    except Exception as e:
                        accelerator.print(str(e))
                        accelerator.print(F0_real.size())
                        accelerator.print(F0_real.squeeze(0).size())

                    loss_test += accelerator.gather(loss_mel).mean()
                    loss_align += accelerator.gather(loss_dur).mean()
                    loss_f += accelerator.gather(loss_F0).mean()

                    iters_test += 1
                except Exception as e:
                    accelerator.print(e)
                    accelerator.print('ooh, something went wrong!')
                    iters_test += 1
                    continue

        if accelerator.is_main_process:
            print('Epochs:', epoch + 1)
            log_print('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n', logger)
            print('\n\n\n')
            writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
            writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
            writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)

        if (epoch + 1) % save_freq == 0:
            if (loss_test / iters_test) < best_loss:
                best_loss = loss_test / iters_test
            print('Saving..')
            state = {
                'net': {key: model[key].state_dict() for key in model},
                'optimizer': optimizer.state_dict(),
                'iters': iters,
                'val_loss': loss_test / iters_test,
                'epoch': epoch,
            }
            save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
            torch.save(state, save_path)

            # if sigma_data was being estimated, save the estimated sigma
            if model_params.diffusion.dist.estimate_sigma_data:
                config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))

                with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
                    yaml.dump(config, outfile, default_flow_style=True)


if __name__ == "__main__":
    main()