import json
import string
from dataclasses import asdict, dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence

from codec.models import PatchVAE
from tts.model.cache_utils import FLACache
from tts.text_processor import BasicTextProcessor
from tts.tools import sequence_mask
from tts.tts import ARTTSModel


@dataclass
class VelocityHeadSamplingParams:
    """
    Velocity head sampling parameters.

    Attributes:
        cfg (float): CFG factor against the unconditional prediction.
        cfg_ref (float): CFG factor against a reference (to be used with a
            cache of size 2*batch_size and unfold).
        temperature (float): scale factor of z0 ~ 𝒩(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 1.3
    cfg_ref: float = 1.5
    temperature: float = 0.9
    num_steps: int = 13
    solver: str = "euler"
    sensitivity: str = "adjoint"


@dataclass
class PatchVAESamplingParams:
    """
    PatchVAE sampling parameters.

    Attributes:
        cfg (float): CFG factor against the unconditional prediction.
        temperature (float): scale factor of z0 ~ 𝒩(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 2.0
    temperature: float = 1.0
    num_steps: int = 10
    solver: str = "euler"
    sensitivity: str = "adjoint"


class PardiSpeech:
    """Text-to-speech pipeline: an AR TTS model generates latent frames that a
    PatchVAE decodes back to waveforms."""

    tts: ARTTSModel
    patchvae: PatchVAE
    text_processor: BasicTextProcessor

    def __init__(
        self,
        tts: ARTTSModel,
        patchvae: PatchVAE,
        text_processor: BasicTextProcessor,
    ):
        self.tts = tts
        self.patchvae = patchvae
        self.text_processor = text_processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        # Accept either a local checkpoint directory or a Hugging Face Hub repo id.
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)

        with open(Path(path) / "config.json", "r") as f:
            config = json.load(f)
        artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
        state_dict = load_file(
            Path(path) / "model.st",
            device=map_location,
        )
        artts_model.load_state_dict(state_dict, assign=True)
        patchvae = PatchVAE.from_pretrained(
            artts_config.patchvae_path,
            map_location=map_location,
        )
        text_processor = BasicTextProcessor(
            str(Path(path) / "pretrained_tokenizer.json")
        )
        return cls(artts_model, patchvae, text_processor)

    def encode_reference(self, wav: torch.Tensor, sr: int):
        import torchaudio

        # Resample the reference to the codec's native rate before encoding.
        new_freq = self.patchvae.wavvae.sampling_rate
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
        return self.patchvae.encode(wav)

    @property
    def sampling_rate(self):
        return self.patchvae.wavvae.sampling_rate

    def text_to_speech(
        self,
        text: str | list[str],
        prefix: tuple[str, torch.Tensor] | None = None,
        patchvae_sampling_params: PatchVAESamplingParams | None = None,
        velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
        prefix_separator: str = ". ",
        max_seq_len: int = 600,
        stop_threshold: float = 0.5,
        cache: FLACache | None = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        text: str | list[str]
            The text(s) to synthesize.
        prefix: tuple[str, torch.Tensor] | None
            A pair (text, speech): a reference speech excerpt encoded with
            `encode_reference` and its corresponding text transcription.
            Synthesis is performed by continuing the prefix. If no prefix is
            given, the first frame is randomly sampled.
        patchvae_sampling_params: PatchVAESamplingParams | None
            PatchVAE sampling parameters (defaults to `PatchVAESamplingParams()`).
        velocity_head_sampling_params: VelocityHeadSamplingParams | None
            Velocity head sampling parameters for AR sampling (defaults to
            `VelocityHeadSamplingParams()`).
        prefix_separator: str
            The separator that joins the prefix text to the target text.
        max_seq_len: int
            The maximum number of latents to generate.
        stop_threshold: float
            Threshold at which AR prediction stops.
        cache: FLACache | None
            Optional generation cache forwarded to `ARTTSModel.generate`.

        Returns
        -------
        A pair (wavs, predictions): the decoded waveforms and the generated
        latent sequences.
        """
        device = next(self.tts.parameters()).device

        if isinstance(text, str):
            text = [text]

        if prefix is not None:
            prefix_text, prefix_speech = prefix
            # Strip trailing punctuation so the separator joins the texts cleanly.
            prefix_text = prefix_text.strip().rstrip(string.punctuation)
            if prefix_text != "":
                text = [prefix_text + prefix_separator + t for t in text]
            prefix_speech = prefix_speech.repeat(len(text), 1, 1)
        else:
            # Without a reference, start from a single random latent frame.
            _, audio_latent_sz = self.tts.audio_embd.weight.shape
            prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device)

        # if self.bos:
        #     text = "[BOS]" + text
        # if self.eos:
        #     text = text + "[EOS]"
        text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text]
        text_pre_mask = sequence_mask(
            torch.tensor([x.shape[0] for x in text_ids])
        ).to(device)
        text_mask = text_pre_mask[:, None] * text_pre_mask[..., None]
        crossatt_mask = text_pre_mask[:, None, None]
        text_ids = pad_sequence(text_ids, batch_first=True)

        if velocity_head_sampling_params is None:
            velocity_head_sampling_params = VelocityHeadSamplingParams()
        if patchvae_sampling_params is None:
            patchvae_sampling_params = PatchVAESamplingParams()

        with torch.inference_mode():
            _, predictions = self.tts.generate(
                text_ids.to(device),
                text_mask=text_mask,
                crossatt_mask=crossatt_mask,
                prefix=prefix_speech.to(device),
                max_seq_len=max_seq_len,
                sampling_params=asdict(velocity_head_sampling_params),
                stop_threshold=stop_threshold,
                cache=cache,
                device=device,
                **kwargs,
            )

        wavs = [
            self.patchvae.decode(p, **asdict(patchvae_sampling_params))
            for p in predictions
        ]
        return wavs, predictions
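

# Minimal usage sketch (illustrative, not part of the library). The checkpoint
# id "pardi/pardi-speech" and the file names below are placeholders, and the
# tensor layouts assumed for the reference clip and the decoded waveform are
# assumptions about `PatchVAE.encode`/`decode`, not confirmed by this module.
if __name__ == "__main__":
    import torchaudio

    model = PardiSpeech.from_pretrained("pardi/pardi-speech")  # placeholder id

    # Optional voice cloning: encode a reference clip (torchaudio.load returns
    # a (channels, time) tensor) and pass it with its transcription as prefix.
    ref_wav, ref_sr = torchaudio.load("reference.wav")
    ref_latents = model.encode_reference(ref_wav, ref_sr)

    wavs, _ = model.text_to_speech(
        "Hello, world.",
        prefix=("Transcription of the reference clip.", ref_latents),
    )

    # Assumes `decode` yields a (channels, time) waveform per batch item.
    torchaudio.save("out.wav", wavs[0].cpu(), model.sampling_rate)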