# pardi-speech / tts / tts.py
# Initial demo of Lina-speech (pardi-speech)
import json
from pathlib import Path
from typing import Any
import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
from tqdm import tqdm
from tts.model.config import TTSConfig
from tts.model.prediction_head import (ContinuousHead, LogitsHead,
StopPredictionHead, VelocityHead)
from tts.model.registry import DECODER_REGISTRY, ENCODER_REGISTRY
from tts.tools import path_matrix, widen_alignment
def collect_heads(cache, selected_heads, last=True):
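    """Gather cross-attention weights for the selected (layer, head) pairs from a
    decoder cache and concatenate them along the head dimension. With last=True
    only the most recent query position is kept (a singleton query dim is
    re-inserted so the result stays 4-D)."""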
if last:
return torch.cat(
[
cache[layer]["crossatt_weights"][:, [head], -1]
for layer, head in selected_heads
],
dim=1,
)[:, :, None]
else:
return torch.cat(
[
cache[layer]["crossatt_weights"][:, [head]]
for layer, head in selected_heads
],
dim=1,
)
def mask_from_abs_pos(abs_pos, text_len, expand_factor, width=(5, 1)):
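    """Turn absolute alignment positions into a cross-attention mask: build a
    path matrix over `text_len` text positions, widen it along the text axis
    with `widen_alignment(width=width)`, then repeat each entry `expand_factor`
    times along the last dimension."""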
exp_ca_mask = path_matrix(abs_pos, text_len)
exp_ca_mask = widen_alignment(exp_ca_mask, width=width, axis="S")
exp_ca_mask = expand(exp_ca_mask, expand_factor)
return exp_ca_mask
def expand(x, r):
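    """Repeat each entry of the last dimension r consecutive times:
    (b, n, d) -> (b, n, r * d)."""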
b, n, d = x.shape
x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
return x
class ARTTSModel(nn.Module):
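    """Autoregressive text-to-speech model: a text encoder feeding an audio
    decoder through cross-attention. Depending on the config, audio frames are
    embedded as discrete tokens (logits head), continuous vectors (continuous
    head), or continuous vectors with a diffusion velocity head. Optional
    extras: a stop-prediction head, a stop-token embedding, learned sink tokens
    prepended to the encoder output, and per-head disabling of cross-attention."""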
def __init__(self, cfg: TTSConfig):
super().__init__()
self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim)
if cfg.audio_input_type == "discrete":
self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim)
self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size)
elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion:
self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
self.prediction_head = VelocityHead(
cfg.decoder_cfg.dim,
cfg.audio_embed_size,
cfg.diffusion_head_num_layers,
)
elif cfg.audio_input_type == "continuous":
self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
self.prediction_head = ContinuousHead(
cfg.decoder_cfg.dim,
cfg.audio_embed_size,
)
self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg)
self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg)
self.stop_token_embd = None
self.stop_prediction_head = None
if cfg.stop_prediction_head:
if cfg.stop_token_embd:
self.stop_token_embd = nn.Embedding(2, cfg.dim, padding_idx=0)
self.stop_prediction_head = StopPredictionHead(cfg.dim)
if cfg.num_sink_tokens > 0:
self.sink_tokens = nn.Parameter(
torch.randn(cfg.num_sink_tokens, cfg.dim) * 0.02, requires_grad=True
)
else:
self.sink_tokens = None
self.disabled_crossatt_head_idx = cfg.disabled_crossatt_head_idx
@property
def num_sink_tokens(self):
if self.sink_tokens is None:
return 0
else:
n_sink, _ = self.sink_tokens.shape
return n_sink
@classmethod
def instantiate_from_config(cls, config):
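        """Build the nested encoder/decoder config objects from a plain dict,
        wrap everything in a TTSConfig, and return (model, config)."""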
for k in config.keys():
if k == "decoder_cfg":
config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k])
if k == "encoder_cfg":
config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k])
config = TTSConfig(**config)
return ARTTSModel(config), config
@classmethod
def from_pretrained_local(
cls,
path: str,
config_filename: str = "config.json",
model_filename: str = "model.st",
device: str = "cpu",
):
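        """Load a model from a local directory containing a JSON config file and
        safetensors weights."""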
with open(Path(path) / config_filename, "r") as f:
config = json.load(f)
model, config = cls.instantiate_from_config(config)
state_dict = load_file(Path(path) / model_filename, device=device)
model.load_state_dict(state_dict)
return model
def _get_query(self, x: torch.Tensor, *args):
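        """Embed raw audio inputs and delegate query construction to the decoder."""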
input_audio_embd = self.audio_embd(x)
return self.audio_decoder._get_query(input_audio_embd, *args)
def forward(
self,
text_ids: torch.LongTensor,
audio_inputs: torch.Tensor,
text_mask: torch.Tensor | None = None,
audio_mask: torch.Tensor | None = None,
stop_tokens: torch.Tensor | None = None,
text_stop_tokens: torch.Tensor | None = None,
text_rel_pos: torch.Tensor | None = None,
crossatt_mask: torch.Tensor | None = None,
crossatt_rel_pos: torch.Tensor | None = None,
n_first_layers: int | None = None,
cache: Any | None = None,
):
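        """Teacher-forced forward pass.

        Text ids are embedded and encoded; audio inputs are embedded shifted by
        one frame, optionally with stop-token embeddings added. If sink tokens
        are configured they are prepended to the encoder states and the
        cross-attention mask is padded accordingly; cross-attention heads listed
        in `disabled_crossatt_head_idx` get their text positions masked out
        (sink positions stay visible). Returns the decoder's pre-head hidden
        states, from all layers or only the first `n_first_layers`."""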
input_text_embd = self.text_embd(text_ids)
input_audio_embd = self.audio_embd(audio_inputs[:, :-1])
if self.stop_token_embd is not None:
if stop_tokens is not None:
stop_tokens_embd = self.stop_token_embd(stop_tokens)
input_audio_embd += stop_tokens_embd[:, :-1]
text_hidden_states = self.text_encoder(
input_text_embd,
mask=text_mask,
text_rel_pos=text_rel_pos,
)
if self.disabled_crossatt_head_idx is not None and crossatt_mask is not None:
crossatt_mask_list = []
n_sink, _ = self.sink_tokens.shape
for layer in self.audio_decoder.decoder_layers:
if layer.crossatt is not None:
h = layer.crossatt.heads
crossatt_layer_mask = (
crossatt_mask.unsqueeze(1).repeat(1, h, 1, 1).clone()
)
crossatt_layer_mask = torch.nn.functional.pad(
crossatt_layer_mask,
(n_sink, 0),
value=True,
)
crossatt_mask_list.append(crossatt_layer_mask[:, :, :-1])
else:
crossatt_mask_list.append(None)
for layer, head in self.disabled_crossatt_head_idx:
crossatt_mask_list[layer][:, head, :, n_sink:] = False
crossatt_mask = crossatt_mask_list
else:
if self.sink_tokens is not None:
n_sink, _ = self.sink_tokens.shape
if crossatt_mask is not None:
crossatt_mask = torch.nn.functional.pad(
crossatt_mask,
(n_sink, 0),
value=True,
)
crossatt_mask = crossatt_mask[:, :-1]
if self.sink_tokens is not None:
sink_tokens = self.sink_tokens[None, :].repeat(
text_hidden_states.shape[0], 1, 1
)
text_hidden_states = torch.cat(
(sink_tokens, text_hidden_states),
dim=1,
)
if n_first_layers is not None:
pre_logits = self.audio_decoder.forward_first_n_layers(
text_hidden_states,
input_audio_embd,
n_first_layers,
crossatt_mask=crossatt_mask,
)
else:
pre_logits = self.audio_decoder(
text_hidden_states,
input_audio_embd,
crossatt_mask=crossatt_mask,
cache=cache,
)
return pre_logits
def generate(
self,
text_ids: torch.LongTensor,
prefix: torch.Tensor,
text_mask: torch.Tensor | None = None,
crossatt_mask: torch.Tensor | None = None,
text_rel_pos: torch.LongTensor | None = None,
teacher_force: torch.Tensor | None = None,
unfold_ref: bool = False,
max_seq_len: int = 200,
device: str = "cuda",
sampling_params: dict | None = None,
stop_threshold: float = 0.5,
cache: Any | None = None,
do_not_stop: bool = False,
):
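        """Autoregressively sample audio frames conditioned on `text_ids`.

        The prompt `prefix` is prefilled into the decoder cache, then frames are
        sampled one by one with `prediction_head.predict`. Supports teacher
        forcing, a paired reference stream (`unfold_ref`), and a stop-prediction
        head with `stop_threshold`; generation stops once every sample in the
        batch has emitted a stop signal (unless `do_not_stop`). Returns the
        decoder cache and a list of per-sample predictions, each truncated at
        its stop index."""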
if sampling_params is None:
sampling_params = {}
if text_ids.ndim == 1:
text_ids = text_ids.unsqueeze(0)
batch_size = text_ids.shape[0]
input_text_embd = self.text_embd(text_ids)
text_hidden_states = self.text_encoder(
input_text_embd,
text_rel_pos=text_rel_pos,
mask=text_mask,
)
prefix_embd = self.audio_embd(prefix)
if self.sink_tokens is not None:
sink_tokens = self.sink_tokens[None, :].repeat(
text_hidden_states.shape[0], 1, 1
)
text_hidden_states = torch.cat(
(sink_tokens, text_hidden_states),
dim=1,
)
if crossatt_mask is not None:
n_sink, _ = self.sink_tokens.shape
crossatt_mask = torch.nn.functional.pad(
crossatt_mask,
(n_sink, 0),
value=True,
)
if cache is None:
cache = self.audio_decoder.init_cache(
max_seq_len + prefix_embd.shape[1], device
)
stop_status = torch.zeros(batch_size, device=device).bool()
        stop_idx = torch.ones(batch_size, device=device).long() * max_seq_len
preds = []
pre_prediction = self.audio_decoder.prefill(
text_hidden_states,
prefix_embd,
cache=cache,
)
prediction = self.prediction_head.predict(
pre_prediction[:, [-1]], **sampling_params
)
prediction_embd = self.audio_embd(prediction)
for i in tqdm(range(max_seq_len)):
pre_prediction = self.audio_decoder.decode_one(
text_hidden_states,
prediction_embd,
cache,
crossatt_mask=crossatt_mask,
)
if unfold_ref:
pre_prediction, pre_prediction_ref = pre_prediction.chunk(2)
else:
pre_prediction_ref = None
            prediction = self.prediction_head.predict(
                pre_prediction, pre_prediction_ref=pre_prediction_ref, **sampling_params
            )
prediction_embd = self.audio_embd(prediction)
if unfold_ref:
prediction_embd = prediction_embd.repeat(2, 1, 1)
if teacher_force is not None:
b, n, d = teacher_force.shape
if i < n:
prediction_embd = self.audio_embd(teacher_force[:, [i]])
preds.append(prediction)
            if self.stop_prediction_head is not None:
                stop_pred = self.stop_prediction_head(pre_prediction).squeeze((1, 2))
                stop_signal = stop_pred > stop_threshold
                stop_status |= stop_signal
                # record the first frame at which each sample signalled a stop
                stop_idx[stop_signal & (stop_idx > i)] = i
                if stop_status.all():
                    if self.stop_token_embd is not None:
                        st_embd = self.stop_token_embd(
                            torch.ones(1, 1, device=device).int()
                        )
                        prediction_embd += st_embd
                    if not do_not_stop:
                        break
                    else:
                        print(f"STOP: {i}")
full_prediction = torch.cat(preds, dim=1)
        full_prediction = [
            x[:stop_idx[i]][None] for i, x in enumerate(full_prediction.unbind())
        ]
return cache, full_prediction
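    # The generate_with_playhead variant below is disabled (kept as a string
    # literal): it depends on a PlayHead model and a text_stop_token_embd that
    # are not defined in this file.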
"""
def generate_with_playhead(
self,
text_ids: torch.LongTensor,
prefix: torch.Tensor,
playhead_model: PlayHead,
selected_heads_idx: list[tuple[int, int]],
text_stop_tokens: torch.LongTensor | None = None,
text_mask: torch.Tensor | None = None,
text_rel_pos: torch.LongTensor | None = None,
teacher_force: torch.Tensor | None = None,
max_seq_len: int = 200,
device: str = "cuda",
sampling_params: dict | None = None,
stop_threshold: float = 0.5,
do_not_stop: bool = False,
width: tuple[int, int] = (5, 1),
abs_pos_start: int = 0,
stop_end_distance_threshold: int = 5,
):
if sampling_params is None:
sampling_params = {}
if text_ids.ndim == 1:
text_ids = text_ids.unsqueeze(0)
input_text_embd = self.text_embd(text_ids)
if self.text_stop_token_embd is not None:
if text_stop_tokens is not None:
text_stop_tokens_embd = self.text_stop_token_embd(text_stop_tokens)
input_text_embd += text_stop_tokens_embd
text_hidden_states = self.text_encoder(
input_text_embd,
text_rel_pos=text_rel_pos,
mask=text_mask,
)
prefix_embd = self.audio_embd(prefix)
text_len = text_hidden_states.shape[1]
if self.sink_tokens is not None:
sink_tokens = self.sink_tokens[None, :].repeat(
text_hidden_states.shape[0], 1, 1
)
text_hidden_states = torch.cat(
(sink_tokens, text_hidden_states),
dim=1,
)
cache = self.audio_decoder.init_cache(max_seq_len, device)
preds = []
pre_prediction = self.audio_decoder.prefill(
text_hidden_states,
prefix_embd,
cache=cache,
)
text_freqs = None
prediction = self.prediction_head.predict(
pre_prediction[:, [-1]], **sampling_params
)
prediction_embd = self.audio_embd(prediction)
preds.append(prediction)
playhead_cache = playhead_model.init_cache()
previous_position = torch.zeros(1, device=device)
abs_pos = torch.ones(1, 1, device=device).long() * abs_pos_start
selected_heads_frame = collect_heads(cache, selected_heads_idx, last=False)
selected_heads_frame = selected_heads_frame.sum(1).transpose(-1, -2)
pos_preds = []
steps = []
expand_crossatt_mask = []
for i in tqdm(range(selected_heads_frame.shape[2])):
pred, step = playhead_model.predict(
selected_heads_frame[..., [i]],
cache=playhead_cache,
previous_position=previous_position,
)
previous_position = pred
abs_pos += step
pos_preds.append(pred)
steps.append(step)
exp_ca_mask = mask_from_abs_pos(
abs_pos,
(text_len // playhead_model.avg_pool_stride) + 1,
playhead_model.avg_pool_stride,
width=width,
)
exp_ca_mask = torch.nn.functional.pad(
exp_ca_mask, (self.num_sink_tokens, 0), value=True
).bool()[..., : text_len + self.num_sink_tokens]
expand_crossatt_mask.append(exp_ca_mask)
print("starting at: ", abs_pos.item())
# pos_pred, step = playhead_model.predict(
# selected_heads_frame,
# cache=playhead_cache,
# previous_position=previous_position,
# )
# previous_position = pos_pred[:, [-1]]
# abs_pos += step
# exp_ca_mask = mask_from_abs_pos(
# abs_pos,
# (text_len // playhead_model.avg_pool_stride) + 1,
# playhead_model.avg_pool_stride,
# width=width,
# )
# expand_crossatt_mask.append(exp_ca_mask)
# steps.append(step)
# pos_preds.append(pos_pred)
progress_bar = tqdm(range(max_seq_len))
for i in progress_bar:
pre_prediction = self.audio_decoder.decode_one(
text_hidden_states,
prediction_embd,
cache,
# text_freqs=text_freqs,
crossatt_mask=exp_ca_mask,
)
prediction = self.prediction_head.predict(pre_prediction, **sampling_params)
prediction_embd = self.audio_embd(prediction)
if teacher_force is not None:
b, n, d = teacher_force.shape
if i < n:
prediction_embd = self.audio_embd(teacher_force[:, [i]])
### PLAYHEAD ========================
selected_heads_frame = (
collect_heads(cache, selected_heads_idx).sum(1).transpose(-1, -2)
)
pos_pred, step = playhead_model.predict(
selected_heads_frame,
cache=playhead_cache,
previous_position=previous_position,
)
previous_position = pos_pred
abs_pos += step
exp_ca_mask = mask_from_abs_pos(
abs_pos,
(text_len // playhead_model.avg_pool_stride) + 1,
playhead_model.avg_pool_stride,
width=width,
)
exp_ca_mask = torch.nn.functional.pad(
exp_ca_mask, (self.num_sink_tokens, 0), value=True
).bool()[..., : text_len + self.num_sink_tokens]
expand_crossatt_mask.append(exp_ca_mask)
steps.append(step)
pos_preds.append(pos_pred)
# =================================
preds.append(prediction)
if self.stop_prediction_head is not None:
stop_pred = self.stop_prediction_head(pre_prediction)
if stop_pred > stop_threshold:
dist = np.abs(
abs_pos.cpu().item() * playhead_model.avg_pool_stride - text_len
)
progress_bar.set_postfix(
{"stop": f"pos: {abs_pos.cpu().item()}; dist{dist}"}
)
if dist < stop_end_distance_threshold and not do_not_stop:
break
progress_bar.set_postfix({"position": abs_pos.cpu().item()})
full_prediction = torch.cat(preds, dim=1)
expand_crossatt_mask = torch.cat(expand_crossatt_mask, dim=1)
print(expand_crossatt_mask.shape)
return cache, full_prediction, expand_crossatt_mask, steps, pos_preds
"""