#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import logging
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.generic_utils import count_parameters

from TTS.config import load_config
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

use_cuda = torch.cuda.is_available()


def parse_args(arg_list: Optional[list[str]]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)
    return parser.parse_args(arg_list)


def setup_loader(config: BaseTTSConfig, ap: AudioProcessor, r, speaker_manager: SpeakerManager, samples) -> DataLoader:
    tokenizer, _ = TTSTokenizer.init_from_config(config)
    dataset = TTSDataset(
        outputs_per_step=r,
        compute_linear_spec=False,
        samples=samples,
        tokenizer=tokenizer,
        ap=ap,
        batch_group_size=0,
        min_text_len=config.min_text_len,
        max_text_len=config.max_text_len,
        min_audio_len=config.min_audio_len,
        max_audio_len=config.max_audio_len,
        phoneme_cache_path=config.phoneme_cache_path,
        precompute_num_workers=0,
        use_noise_augment=False,
        speaker_id_mapping=speaker_manager.name_to_id if config.use_speaker_embedding else None,
        d_vector_mapping=speaker_manager.embeddings if config.use_d_vector_file else None,
    )
    if config.use_phonemes and config.compute_input_seq_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(config.num_loader_workers)
    dataset.preprocess_samples()
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )


def set_filename(wav_path: str, out_path: Path) -> tuple[Path, Path, Path, Path]:
    wav_name = Path(wav_path).stem
    (out_path / "quant").mkdir(exist_ok=True, parents=True)
    (out_path / "mel").mkdir(exist_ok=True, parents=True)
    (out_path / "wav_gl").mkdir(exist_ok=True, parents=True)
    (out_path / "wav").mkdir(exist_ok=True, parents=True)
    wavq_path = out_path / "quant" / wav_name
    mel_path = out_path / "mel" / wav_name
    wav_gl_path = out_path / "wav_gl" / f"{wav_name}.wav"
    out_wav_path = out_path / "wav" / f"{wav_name}.wav"
    return wavq_path, mel_path, wav_gl_path, out_wav_path
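

# Resulting layout under --output_path (derived from set_filename above;
# the .npy extensions are appended later by np.save):
#   quant/<utt>.npy   mel/<utt>.npy   wav_gl/<utt>.wav   wav/<utt>.wav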
def format_data(data):
    # setup input data
    text_input = data["token_id"]
    text_lengths = data["token_id_lengths"]
    mel_input = data["mel"]
    mel_lengths = data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())
    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if d_vectors is not None:
            d_vectors = d_vectors.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )
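

# NOTE: the tuple order returned by format_data() must match the unpacking
# at the top of extract_spectrograms() below.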
@torch.inference_mode()
def inference(
    model_name: str,
    model: BaseTTS,
    ap: AudioProcessor,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    speaker_ids=None,
    d_vectors=None,
) -> np.ndarray:
    if model_name == "glow_tts":
        speaker_c = None
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif d_vectors is not None:
            speaker_c = d_vectors
        outputs = model.inference_with_MAS(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
        )
        model_output = outputs["model_outputs"]
        return model_output.detach().cpu().numpy()
    if "tacotron" in model_name:
        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
        postnet_outputs = outputs["model_outputs"]
        # normalize tacotron output
        if model_name == "tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            return torch.stack(mel_specs).cpu().numpy()
        if model_name == "tacotron2":
            return postnet_outputs.detach().cpu().numpy()
    msg = f"Model not supported: {model_name}"
    raise ValueError(msg)
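

# inference() returns batched mel spectrograms shaped [batch, frames, channels];
# extract_spectrograms() below trims each item to its true length and transposes it.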
def extract_spectrograms(
    model_name: str,
    data_loader: DataLoader,
    model: BaseTTS,
    ap: AudioProcessor,
    output_path: Path,
    quantize_bits: int = 0,
    save_audio: bool = False,
    debug: bool = False,
    metadata_name: str = "metadata.txt",
) -> None:
    model.eval()
    export_metadata = []
    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
            _,
            _,
            _,
            item_idx,
        ) = format_data(data)
        model_output = inference(
            model_name,
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
        )
        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
            # quantize and save wav
            if quantize_bits > 0:
                wavq = quantize(wav, quantize_bits)
                np.save(wavq_path, wavq)
            # save TTS mel
            mel = model_output[idx]
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)
            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)
            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)
    # write metadata as "<source wav path>|<mel path>.npy" pairs
    with (output_path / metadata_name).open("w") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1]}.npy\n")


def main(arg_list: Optional[list[str]] = None) -> None:
    setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
    args = parse_args(arg_list)
    config = load_config(args.config_path)
    config.audio.trim_silence = False
    # Audio processor
    ap = AudioProcessor(**config.audio)
    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        config.datasets,
        eval_split=args.eval,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval
    # init speaker manager
    if config.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif config.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
    else:
        speaker_manager = None
    # setup model
    model = setup_model(config)
    # restore model
    model.load_checkpoint(config, args.checkpoint_path, eval=True)
    if use_cuda:
        model.cuda()
    num_params = count_parameters(model)
    print(f"\n > Model has {num_params} parameters", flush=True)
    # set r
    r = 1 if config.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(config, ap, r, speaker_manager, meta_data)
    extract_spectrograms(
        config.model.lower(),
        own_loader,
        model,
        ap,
        Path(args.output_path),
        quantize_bits=args.quantize_bits,
        save_audio=args.save_audio,
        debug=args.debug,
        metadata_name="metadata.txt",
    )
    sys.exit(0)


if __name__ == "__main__":
    main()