|
|
|
import os.path as osp |
|
import random |
|
import numpy as np |
|
import random |
|
import soundfile as sf |
|
import librosa |
|
|
|
import torch |
|
import torchaudio |
|
import torch.utils.data |
|
import torch.distributed as dist |
|
from multiprocessing import Pool |
|
|
|
import logging |
|
# Module-level logger. Its level is forced to DEBUG here; no handler is
# attached at this point, so where (and whether) records appear depends on
# the application's logging configuration.
logger = logging.getLogger(__name__)

logger.setLevel(logging.DEBUG)
|
|
|
import pandas as pd |
|
|
|
class TextCleaner:
    """Map each character of a string to its integer symbol index.

    Characters missing from ``symbol_dict`` are dropped from the output; when
    ``debug`` is true a warning naming the character is printed for each one.
    """

    def __init__(self, symbol_dict, debug=True):
        # Character -> index lookup (attribute name kept for compatibility).
        self.word_index_dictionary = symbol_dict
        self.debug = debug

    def __call__(self, text):
        """Return the list of indices for the known characters of ``text``."""
        lookup = self.word_index_dictionary
        encoded = []
        for symbol in text:
            try:
                code = lookup[symbol]
            except KeyError:
                if self.debug:
                    print("\nWARNING UNKNOWN IPA CHARACTERS/LETTERS: ", symbol)
                    print("To ignore set 'debug' to false in the config")
            else:
                encoded.append(code)
        return encoded
|
|
|
# Seed the global NumPy and Python RNGs at import time; anything that later
# draws from them (np.random / random) starts from this fixed state.
np.random.seed(1)
random.seed(1)

# STFT settings for the module-level mel transform below.
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 1200,
    "hop_length": 300,
}
MEL_PARAMS = {
    "n_mels": 80,
}

# Mel-spectrogram transform used by ``preprocess``, built from the two
# parameter dictionaries above (n_mels=80, n_fft=2048, win_length=1200,
# hop_length=300).
to_mel = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS, **SPECT_PARAMS)

# Normalisation constants applied to the log-mel spectrogram.
mean, std = -4, 4
|
|
|
def preprocess(wave):
    """Convert a waveform array into a normalised log-mel spectrogram tensor.

    Args:
        wave: NumPy array of samples (passed to ``torch.from_numpy``).

    Returns:
        ``(log(mel + 1e-5) - mean) / std`` where ``mel`` is the module-level
        ``to_mel`` transform of the float32 samples, with a new leading
        dimension of size 1.
    """
    samples = torch.from_numpy(wave).float()
    spectrogram = to_mel(samples).unsqueeze(0)
    # The small epsilon keeps log() finite on silent bins.
    log_spec = torch.log(1e-5 + spectrogram)
    return (log_spec - mean) / std
|
|
|
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset of ``(mel, text_ids, path, wave)`` items read from a file list.

    Each entry of ``data_list`` is a ``"relative/wave/path|text"`` line with
    exactly two ``|``-separated fields; audio is read from ``root_path``
    joined with the relative path.

    Args:
        data_list: Iterable of ``"path|text"`` strings.
        root_path: Directory the wave paths are relative to.
        symbol_dict: Character -> index mapping passed to ``TextCleaner``.
        sr: Stored as ``self.sr``. NOTE(review): ``_load_tensor`` always
            resamples to 24000 Hz regardless of this value.
        data_augmentation: Stored (forced off when ``validation`` is true);
            not consulted by the methods of this class.
        validation: Disables ``data_augmentation`` when true.
        debug: Forwarded to ``TextCleaner`` (prints unknown characters).
    """

    def __init__(self,
                 data_list,
                 root_path,
                 symbol_dict,
                 sr=24000,
                 data_augmentation=False,
                 validation=False,
                 debug=True
                 ):
        self.data_list = [line.strip().split('|') for line in data_list]
        self.text_cleaner = TextCleaner(symbol_dict, debug)
        self.sr = sr

        # Kept for external callers; not used inside this class.
        self.df = pd.DataFrame(self.data_list)

        # NOTE(review): built from MEL_PARAMS only (torchaudio STFT defaults),
        # unlike the module-level ``to_mel`` used by ``preprocess``. Not used
        # inside this class.
        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192

        self.root_path = root_path

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load item ``idx`` as ``(mel, text_ids, path, wave)``.

        The mel spectrogram is trimmed to an even number of frames.
        """
        data = self.data_list[idx]
        path = data[0]

        wave, text_tensor = self._load_tensor(data)

        mel_tensor = preprocess(wave).squeeze()
        # Drop the last frame when needed so the frame count is even.
        n_frames = mel_tensor.size(1)
        acoustic_feature = mel_tensor[:, :(n_frames - n_frames % 2)]

        return acoustic_feature, text_tensor, path, wave

    def _load_tensor(self, data):
        """Read one ``[wave_path, text]`` entry.

        Returns:
            ``(wave, text)``: ``wave`` is a 1-D NumPy array at 24000 Hz with
            12000 zeros prepended and appended; ``text`` is a ``LongTensor`` of
            symbol ids with a 0 id added at both ends.
        """
        wave_path, text = data
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        # soundfile returns a 1-D array for mono and (frames, channels) for
        # multi-channel audio; keep the first channel. Testing the array's
        # dimensionality (rather than ``shape[-1] == 2``) avoids misreading a
        # 2-sample mono clip as stereo and also handles >2 channels. No
        # squeeze(), so a 1-frame clip stays 1-D.
        if wave.ndim > 1:
            wave = wave[:, 0]
        if sr != 24000:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
            print(wave_path, sr)

        # Half a second of zeros on each side at 24 kHz.
        wave = np.concatenate([np.zeros([12000]), wave, np.zeros([12000])], axis=0)

        text_ids = [0] + self.text_cleaner(text) + [0]
        return wave, torch.LongTensor(text_ids)

    def _load_data(self, data):
        """Return the mel spectrogram for ``data``, randomly cropped to at most
        ``self.max_mel_length`` frames (uses the global NumPy RNG)."""
        wave, _ = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # randint's upper bound is exclusive; +1 lets the final window
            # (ending on the last frame) be chosen as well.
            random_start = np.random.randint(0, mel_length - self.max_mel_length + 1)
            mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]

        return mel_tensor
|
|
|
|
|
class Collater:
    """Pad and stack dataset items into batch tensors.

    Each batch item is a ``(mel, text, path, wave)`` tuple. Items are ordered
    by mel frame count, longest first, before padding.

    Args:
        return_wave (bool): stored as ``self.return_wave``. NOTE(review):
            ``__call__`` always returns the waves and does not read it.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        # Stored for compatibility; not read by __call__.
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Collate ``batch``.

        Returns:
            ``(waves, texts, input_lengths, mels, output_lengths)``: ``texts``
            is a zero-padded ``(B, max_text_len)`` int64 tensor, ``mels`` a
            zero-padded ``(B, n_mels, max_mel_len)`` float32 tensor, the length
            tensors hold each item's token / frame count, and ``waves`` is a
            list, all in the same length-sorted order.
        """
        frame_counts = [item[0].shape[1] for item in batch]
        # Reversed np.argsort (not sorted(..., reverse=True)) so ties keep the
        # same order as before.
        ordered = [batch[k] for k in np.argsort(frame_counts)[::-1]]

        n_items = len(ordered)
        n_mels = ordered[0][0].size(0)
        longest_mel = max(item[0].shape[1] for item in ordered)
        longest_text = max(item[1].shape[0] for item in ordered)

        mels = torch.zeros((n_items, n_mels, longest_mel), dtype=torch.float32)
        texts = torch.zeros((n_items, longest_text), dtype=torch.long)
        input_lengths = torch.zeros(n_items, dtype=torch.long)
        output_lengths = torch.zeros(n_items, dtype=torch.long)
        waves = [None] * n_items

        for row, (mel, text, _path, wave) in enumerate(ordered):
            n_frames = mel.size(1)
            n_tokens = text.size(0)
            mels[row, :, :n_frames] = mel
            texts[row, :n_tokens] = text
            input_lengths[row] = n_tokens
            output_lengths[row] = n_frames
            waves[row] = wave

        return waves, texts, input_lengths, mels, output_lengths
|
|
|
|
|
def get_length(wave_path, root_path):
    """Return the length of an audio file expressed in 24 kHz samples.

    Uses the ``sf.info`` metadata of ``root_path/wave_path``: the frame count
    is rescaled from the file's sample rate to 24000 Hz, so the result is a
    float.
    """
    meta = sf.info(osp.join(root_path, wave_path))
    scale = 24000 / meta.samplerate
    return meta.frames * scale
|
|
|
def build_dataloader(path_list,
                     root_path,
                     symbol_dict,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a DataLoader whose batches group utterances of similar length.

    Args:
        path_list: ``"wave_path|text"`` lines passed to ``FilePathDataset``.
        root_path: Directory the wave paths are relative to.
        symbol_dict: Character -> index mapping for text cleaning.
        validation: When true, batches are not shuffled and incomplete
            batches are kept; also forwarded to ``FilePathDataset``.
        batch_size: Items per batch.
        num_workers: DataLoader workers; ``2 * num_workers`` processes read
            the audio lengths (done serially when ``num_workers`` is 0).
        device: Memory is pinned unless this equals ``"cpu"``.
        collate_config: Optional keyword arguments for ``Collater``.
        dataset_config: Optional keyword arguments for ``FilePathDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by ``BatchSampler``
        (single replica, rank 0).
    """
    # None sentinels instead of ``{}`` defaults so no dict is shared between
    # calls.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(path_list, root_path, symbol_dict,
                              validation=validation, **dataset_config)
    collate_fn = Collater(**collate_config)

    print("Getting sample lengths...")

    num_processes = num_workers * 2
    if num_processes != 0:
        list_of_tuples = [(d[0], root_path) for d in dataset.data_list]
        with Pool(processes=num_processes) as pool:
            sample_lengths = pool.starmap(get_length, list_of_tuples, chunksize=16)
    else:
        sample_lengths = [get_length(d[0], root_path) for d in dataset.data_list]

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=BatchSampler(
            sample_lengths,
            batch_size,
            shuffle=(not validation),
            drop_last=(not validation),
            num_replicas=1,
            rank=0,
        ),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )

    return data_loader
|
|
|
|
|
class BatchSampler(torch.utils.data.Sampler):
    """Yield batches of dataset indices drawn from a single length bin.

    Lengths are measured in 300-sample frames and grouped by ``get_time_bin``
    into 20-frame-wide bins; samples shorter than 20 frames are excluded.
    Every batch comes from one bin, and each bin's indices are split across
    replicas with ``DistributedSampler``.

    Args:
        sample_lengths: Per-sample lengths in samples (entry i describes
            dataset index i).
        batch_sizes: Number of indices per batch (a single int).
        num_replicas: Number of replicas; ``None`` uses
            ``torch.distributed.get_world_size()``.
        rank: This replica's rank; ``None`` uses
            ``torch.distributed.get_rank()``.
        shuffle: Shuffle the bin order and, via ``DistributedSampler``, the
            order inside each bin.
        drop_last: Passed to both the per-bin ``DistributedSampler`` and the
            per-bin batch sampler; also used by ``__len__``.
    """

    def __init__(
        self,
        sample_lengths,
        batch_sizes,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
    ):
        self.batch_sizes = batch_sizes
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.time_bins = {}  # bin key -> list of dataset indices
        self.epoch = 0
        self.total_len = 0
        self.last_bin = None  # bin key of the most recently yielded batch

        for index, length in enumerate(sample_lengths):
            bin_num = self.get_time_bin(length)
            if bin_num != -1:
                self.time_bins.setdefault(bin_num, []).append(index)

        # Use the resolved ``self.num_replicas``: the raw argument may be
        # None, and multiplying by it raised a TypeError.
        total_batch = self.batch_sizes * self.num_replicas
        for indices in self.time_bins.values():
            self.total_len += len(indices) // total_batch
            if not self.drop_last and len(indices) % total_batch != 0:
                self.total_len += 1

    def __iter__(self):
        bin_keys = list(self.time_bins.keys())
        if self.shuffle:
            # NOTE(review): drawn from the global torch RNG (no generator, not
            # tied to self.epoch), so replicas agree on the bin order only if
            # their torch RNG states agree.
            visit_order = torch.randperm(len(bin_keys)).tolist()
        else:
            visit_order = list(range(len(bin_keys)))

        for position in visit_order:
            key = bin_keys[position]
            current_bin = self.time_bins[key]
            # Named so it does not shadow the module-level ``dist`` import.
            replica_sampler = torch.utils.data.distributed.DistributedSampler(
                current_bin,
                num_replicas=self.num_replicas,
                rank=self.rank,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            # Within-bin shuffling follows self.epoch; callers should call
            # set_epoch() each epoch to get a different order.
            replica_sampler.set_epoch(self.epoch)
            batcher = torch.utils.data.sampler.BatchSampler(
                replica_sampler, self.batch_sizes, self.drop_last
            )
            for item_list in batcher:
                self.last_bin = key
                # The samplers yield positions inside the bin; map them back
                # to dataset indices.
                yield [current_bin[i] for i in item_list]

    def __len__(self):
        """Precomputed number of batches this replica yields per epoch."""
        return self.total_len

    def set_epoch(self, epoch):
        """Set the epoch forwarded to each per-bin ``DistributedSampler``."""
        self.epoch = epoch

    def get_time_bin(self, sample_count):
        """Map a length in samples to a bin key, or -1 if under 20 frames.

        Frames are 300 samples; bin ``k`` holds ``20 + 20k <= frames < 40 + 20k``.
        """
        frames = sample_count // 300
        if frames >= 20:
            return (frames - 20) // 20
        return -1