Commit 
							
							·
						
						fc5b11b
	
1
								Parent(s):
							
							883b1e0
								
Upload train_ms.py with huggingface_hub
Browse files- train_ms.py +299 -0
    	
        train_ms.py
    ADDED
    
    | @@ -0,0 +1,299 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import argparse
         | 
| 4 | 
            +
            import itertools
         | 
| 5 | 
            +
            import math
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from torch import nn, optim
         | 
| 8 | 
            +
            from torch.nn import functional as F
         | 
| 9 | 
            +
            from torch.utils.data import DataLoader
         | 
| 10 | 
            +
            from torch.utils.tensorboard import SummaryWriter
         | 
| 11 | 
            +
            import torch.multiprocessing as mp
         | 
| 12 | 
            +
            import torch.distributed as dist
         | 
| 13 | 
            +
            from torch.nn.parallel import DistributedDataParallel as DDP
         | 
| 14 | 
            +
            from torch.cuda.amp import autocast, GradScaler
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import librosa
         | 
| 17 | 
            +
            import logging
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Raise the threshold of the 'numba' logger to WARNING so its lower-severity records are dropped.
            logging.getLogger('numba').setLevel(logging.WARNING)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import commons
         | 
| 22 | 
            +
            import utils
         | 
| 23 | 
            +
            from data_utils import (
         | 
| 24 | 
            +
              TextAudioSpeakerLoader,
         | 
| 25 | 
            +
              TextAudioSpeakerCollate,
         | 
| 26 | 
            +
              DistributedBucketSampler
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
            from models import (
         | 
| 29 | 
            +
              SynthesizerTrn,
         | 
| 30 | 
            +
              MultiPeriodDiscriminator,
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
            from losses import (
         | 
| 33 | 
            +
              generator_loss,
         | 
| 34 | 
            +
              discriminator_loss,
         | 
| 35 | 
            +
              feature_loss,
         | 
| 36 | 
            +
              kl_loss
         | 
| 37 | 
            +
            )
         | 
| 38 | 
            +
            from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
         | 
| 39 | 
            +
            from text.symbols import symbols
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            # Enable cuDNN benchmark mode for the whole process.
            # NOTE(review): presumably a speed optimisation for the conv layers; confirm it pays off
            # with the variable-length batches produced by the bucket sampler.
            torch.backends.cudnn.benchmark = True
         | 
| 43 | 
            +
            # Module-level training step counter: run() re-initialises it (0, or derived from the
            # resumed epoch) and train_and_evaluate() increments it once per training batch.
            global_step = 0
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
def main():
  """Spawn one ``run`` worker per visible CUDA device on this machine.

  Only single-node, multi-GPU training is assumed: the rendezvous address is
  fixed to localhost:8000 and the number of workers equals the number of
  CUDA devices reported by torch.
  """
  assert torch.cuda.is_available(), "CPU training is not allowed."

  world_size = torch.cuda.device_count()

  # Rendezvous settings picked up by the 'env://' init method used in run().
  os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '8000'})

  hparams = utils.get_hparams()
  # mp.spawn prepends the process index, so run() receives (rank, n_gpus, hps).
  mp.spawn(run, args=(world_size, hparams), nprocs=world_size)
| 56 | 
            +
             | 
| 57 | 
            +
             | 
def run(rank, n_gpus, hps):
  """Per-GPU training worker started by ``mp.spawn`` from ``main``.

  Joins the NCCL process group, builds the data loaders and the generator /
  discriminator pair with their AdamW optimizers, tries to resume from the
  latest ``G_*.pth`` / ``D_*.pth`` checkpoints in ``hps.model_dir``, and then
  trains for the remaining epochs, stepping both LR schedulers once per epoch.

  Args:
    rank: index of this worker; also used as its CUDA device index.
    n_gpus: number of workers, passed as the process-group world size.
    hps: hyper-parameters; ``model_dir``, ``data``, ``train`` and ``model``
      are read here.
  """
  global global_step
  is_main = rank == 0

  # The logger, the TensorBoard writers and the validation loader only exist
  # on rank 0; every other rank hands None to train_and_evaluate().
  logger = None
  writers = None
  eval_loader = None
  if is_main:
    logger = utils.get_logger(hps.model_dir)
    logger.info(hps)
    utils.check_git_hash(hps.model_dir)
    writer = SummaryWriter(log_dir=hps.model_dir)
    writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
    writers = [writer, writer_eval]

  dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
  torch.manual_seed(hps.train.seed)
  torch.cuda.set_device(rank)

  train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
  train_sampler = DistributedBucketSampler(
      train_dataset,
      hps.train.batch_size,
      [32,300,400,500,600,700,800,900,1000],
      num_replicas=n_gpus,
      rank=rank,
      shuffle=True)
  collate_fn = TextAudioSpeakerCollate()
  train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
      collate_fn=collate_fn, batch_sampler=train_sampler)
  if is_main:
    eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
    eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,
        batch_size=hps.train.batch_size, pin_memory=True,
        drop_last=False, collate_fn=collate_fn)

  net_g = SynthesizerTrn(
      len(symbols),
      hps.data.filter_length // 2 + 1,
      hps.train.segment_size // hps.data.hop_length,
      n_speakers=hps.data.n_speakers,
      **hps.model).cuda(rank)
  net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
  optim_g = torch.optim.AdamW(
      net_g.parameters(),
      hps.train.learning_rate,
      betas=hps.train.betas,
      eps=hps.train.eps)
  optim_d = torch.optim.AdamW(
      net_d.parameters(),
      hps.train.learning_rate,
      betas=hps.train.betas,
      eps=hps.train.eps)
  net_g = DDP(net_g, device_ids=[rank])
  net_d = DDP(net_d, device_ids=[rank])

  # Best-effort resume. Only ordinary errors (e.g. no checkpoint yet) fall
  # back to a fresh start; KeyboardInterrupt / SystemExit now propagate
  # instead of being swallowed by a bare ``except``.
  # NOTE(review): if the generator checkpoint loads but the discriminator one
  # fails, the generator keeps the loaded state while the epoch restarts at 1.
  try:
    _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
    _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
    global_step = (epoch_str - 1) * len(train_loader)
  except Exception as err:
    if is_main:
      logger.info("No checkpoint restored from {} ({!r}); starting from epoch 1.".format(hps.model_dir, err))
    epoch_str = 1
    global_step = 0

  # last_epoch is epoch_str - 2, i.e. -1 when starting from epoch 1.
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)

  scaler = GradScaler(enabled=hps.train.fp16_run)

  for epoch in range(epoch_str, hps.train.epochs + 1):
    train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, writers)
    scheduler_g.step()
    scheduler_d.step()
| 128 | 
            +
             | 
| 129 | 
            +
             | 
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
  """Train both networks for one epoch; rank 0 also logs, evaluates and saves.

  For every training batch the discriminator is updated first, then the
  generator, and the shared gradient scaler is updated once afterwards.

  Args:
    rank: worker index, used as the CUDA device for the batch tensors.
    epoch: epoch number, forwarded to the batch sampler and to the checkpoints.
    hps: hyper-parameters (``hps.data`` and ``hps.train`` are read).
    nets: ``[generator, discriminator]``.
    optims: ``[generator optimizer, discriminator optimizer]``.
    schedulers: ``[generator scheduler, discriminator scheduler]``; unpacked
      but not otherwise used in this function.
    scaler: gradient scaler shared by both optimizers.
    loaders: ``[train_loader, eval_loader]``.
    logger: logger, only used when ``rank == 0``.
    writers: ``[writer, writer_eval]`` or ``None``.
  """
  global global_step

  net_g, net_d = nets
  optim_g, optim_d = optims
  _sched_g, _sched_d = schedulers
  train_loader, eval_loader = loaders
  if writers is not None:
    writer, writer_eval = writers

  train_loader.batch_sampler.set_epoch(epoch)

  data_cfg = hps.data
  train_cfg = hps.train
  is_main = rank == 0

  net_g.train()
  net_d.train()
  for batch_idx, batch in enumerate(train_loader):
    # Move the seven batch tensors to this worker's GPU, in batch order.
    x, x_lengths, spec, spec_lengths, y, y_lengths, speakers = (
        t.cuda(rank, non_blocking=True) for t in batch)

    with autocast(enabled=train_cfg.fp16_run):
      (y_hat, l_length, attn, ids_slice, x_mask, z_mask,
       (z, z_p, m_p, logs_p, m_q, logs_q)) = net_g(x, x_lengths, spec, spec_lengths, speakers)

      mel = spec_to_mel_torch(
          spec,
          data_cfg.filter_length,
          data_cfg.n_mel_channels,
          data_cfg.sampling_rate,
          data_cfg.mel_fmin,
          data_cfg.mel_fmax)
      y_mel = commons.slice_segments(mel, ids_slice, train_cfg.segment_size // data_cfg.hop_length)
      y_hat_mel = mel_spectrogram_torch(
          y_hat.squeeze(1),
          data_cfg.filter_length,
          data_cfg.n_mel_channels,
          data_cfg.sampling_rate,
          data_cfg.hop_length,
          data_cfg.win_length,
          data_cfg.mel_fmin,
          data_cfg.mel_fmax)

      # Cut the waveform at the same offsets: ids_slice is scaled by
      # hop_length to index into y.
      y = commons.slice_segments(y, ids_slice * data_cfg.hop_length, train_cfg.segment_size)

      # Discriminator pass: the generated audio is detached so this loss
      # does not back-propagate into the generator.
      y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
      with autocast(enabled=False):
        loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
        loss_disc_all = loss_disc

    optim_d.zero_grad()
    scaler.scale(loss_disc_all).backward()
    scaler.unscale_(optim_d)
    grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
    scaler.step(optim_d)

    with autocast(enabled=train_cfg.fp16_run):
      # Generator pass: same discriminator, this time on the attached y_hat.
      y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
      with autocast(enabled=False):
        loss_dur = torch.sum(l_length.float())
        loss_mel = F.l1_loss(y_mel, y_hat_mel) * train_cfg.c_mel
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * train_cfg.c_kl

        loss_fm = feature_loss(fmap_r, fmap_g)
        loss_gen, losses_gen = generator_loss(y_d_hat_g)
        loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

    optim_g.zero_grad()
    scaler.scale(loss_gen_all).backward()
    scaler.unscale_(optim_g)
    grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
    scaler.step(optim_g)
    # One scaler update per iteration, after both optimizers have stepped.
    scaler.update()

    if is_main and global_step % train_cfg.log_interval == 0:
      lr = optim_g.param_groups[0]['lr']
      losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
      progress = 100. * batch_idx / len(train_loader)
      logger.info('Train Epoch: {} [{:.0f}%]'.format(epoch, progress))
      logger.info([t.item() for t in losses] + [global_step, lr])

      scalar_dict = {
          "loss/g/total": loss_gen_all,
          "loss/d/total": loss_disc_all,
          "learning_rate": lr,
          "grad_norm_d": grad_norm_d,
          "grad_norm_g": grad_norm_g,
          "loss/g/fm": loss_fm,
          "loss/g/mel": loss_mel,
          "loss/g/dur": loss_dur,
          "loss/g/kl": loss_kl,
      }
      # Per-sub-loss entries, appended in the order g, d_r, d_g.
      for key_fmt, parts in (
          ("loss/g/{}", losses_gen),
          ("loss/d_r/{}", losses_disc_r),
          ("loss/d_g/{}", losses_disc_g)):
        for i, v in enumerate(parts):
          scalar_dict[key_fmt.format(i)] = v

      image_dict = {
          "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
          "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
          "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
          "all/attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy()),
      }
      utils.summarize(
          writer=writer,
          global_step=global_step,
          images=image_dict,
          scalars=scalar_dict)

    if is_main and global_step % train_cfg.eval_interval == 0:
      evaluate(hps, net_g, eval_loader, writer_eval)
      # NOTE(review): both checkpoints are overwritten at these fixed paths on
      # every save, while run() resumes from hps.model_dir — confirm that
      # model_dir points at this same directory.
      utils.save_checkpoint(net_g, optim_g, train_cfg.learning_rate, epoch, '/content/drive/MyDrive/Genshin_ms/G_ms.pth')
      utils.save_checkpoint(net_d, optim_d, train_cfg.learning_rate, epoch, '/content/drive/MyDrive/Genshin_ms/D_ms.pth')

    global_step += 1

  if is_main:
    logger.info('====> Epoch: {}'.format(epoch))
| 238 | 
            +
             | 
| 239 | 
            +
             
         | 
def evaluate(hps, generator, eval_loader, writer_eval):
  """Synthesize the first validation item and log it through ``utils.summarize``.

  Only the first item of the first batch from ``eval_loader`` is used. The
  generated mel image and audio are always logged; the ground-truth mel and
  audio are added only while the module-level ``global_step`` is 0.

  The generator is put in eval mode for the duration of the call and is
  always switched back to train mode, even if evaluation raises. When
  ``eval_loader`` yields no batch, a warning is logged and nothing is
  summarized (previously this path failed on unbound batch variables).

  Args:
    hps: hyper-parameters; ``hps.data`` supplies the mel/STFT settings.
    generator: wrapped generator; inference goes through ``generator.module``.
    eval_loader: iterable of validation batches (7 tensors each).
    writer_eval: writer handed to ``utils.summarize``.
  """
  generator.eval()
  try:
    with torch.no_grad():
      first_batch = next(iter(eval_loader), None)
      if first_batch is None:
        logging.getLogger(__name__).warning("Evaluation skipped: eval_loader yielded no batches.")
        return

      # Keep only the first item and slice before the host-to-device copy, so
      # the rest of the batch is never transferred. Evaluation runs on
      # device 0 (it is only called from rank 0).
      x, x_lengths, spec, _spec_lengths, y, y_lengths, speakers = (
          t[:1].cuda(0) for t in first_batch)

      y_hat, _attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)
      y_hat_lengths = mask.sum([1,2]).long() * hps.data.hop_length

      mel = spec_to_mel_torch(
          spec,
          hps.data.filter_length,
          hps.data.n_mel_channels,
          hps.data.sampling_rate,
          hps.data.mel_fmin,
          hps.data.mel_fmax)
      y_hat_mel = mel_spectrogram_torch(
          y_hat.squeeze(1).float(),
          hps.data.filter_length,
          hps.data.n_mel_channels,
          hps.data.sampling_rate,
          hps.data.hop_length,
          hps.data.win_length,
          hps.data.mel_fmin,
          hps.data.mel_fmax)

    image_dict = {
        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
    }
    audio_dict = {
        "gen/audio": y_hat[0,:,:y_hat_lengths[0]]
    }
    # The ground truth is logged only at step 0.
    if global_step == 0:
      image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
      audio_dict.update({"gt/audio": y[0,:,:y_lengths[0]]})

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate)
  finally:
    generator.train()
| 296 | 
            +
             | 
| 297 | 
            +
                                       
         | 
| 298 | 
            +
            if __name__ == "__main__":
         | 
| 299 | 
            +
              # Start training only when this file is executed as a script, not when it is imported.
              main()
         |