import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
from tqdm import tqdm

from tts.model.config import TTSConfig
from tts.model.prediction_head import (ContinuousHead, LogitsHead,
                                       StopPredictionHead, VelocityHead)
from tts.model.registry import DECODER_REGISTRY, ENCODER_REGISTRY
from tts.tools import path_matrix, widen_alignment


def collect_heads(cache, selected_heads, last=True):
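    """Gather cross-attention maps for the given (layer, head) pairs.

    Assumes each cache entry holds attention weights of shape
    (batch, heads, query_len, key_len). With ``last=True`` only the newest
    query position is kept (a singleton query axis is restored); either way
    the selected heads are stacked along dim 1.
    """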
    if last:
        return torch.cat(
            [
                cache[layer]["crossatt_weights"][:, [head], -1]
                for layer, head in selected_heads
            ],
            dim=1,
        )[:, :, None]
    else:
        return torch.cat(
            [
                cache[layer]["crossatt_weights"][:, [head]]
                for layer, head in selected_heads
            ],
            dim=1,
        )


def mask_from_abs_pos(abs_pos, text_len, expand_factor, width=(5, 1)):
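    """Cross-attention mask around an absolute alignment position.

    Builds a path matrix at ``abs_pos`` over ``text_len`` text positions,
    widens it by ``width`` along the text (``"S"``) axis, then repeats each
    text position ``expand_factor`` times (see ``expand`` below), which
    undoes the playhead's average pooling.
    """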
    exp_ca_mask = path_matrix(abs_pos, text_len)
    exp_ca_mask = widen_alignment(exp_ca_mask, width=width, axis="S")
    exp_ca_mask = expand(exp_ca_mask, expand_factor)
    return exp_ca_mask


def expand(x, r):
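    """Repeat each element along the last axis ``r`` times: (b, n, d) -> (b, n, r * d)."""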
    b, n, d = x.shape
    x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
    return x


class ARTTSModel(nn.Module):
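    """Autoregressive text-to-speech model.

    A text encoder produces a memory that an autoregressive audio decoder
    cross-attends to. Depending on ``cfg.audio_input_type``, the audio side
    uses discrete tokens (logits head), continuous embeddings (regression
    head), or continuous embeddings with a diffusion velocity head.
    """
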
    def __init__(self, cfg: TTSConfig):
        super().__init__()
        self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim)
        # Pick the audio embedding and prediction head for the configured input type.
        if cfg.audio_input_type == "discrete":
            self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim)
            self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size)
        elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion:
            self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
            self.prediction_head = VelocityHead(
                cfg.decoder_cfg.dim,
                cfg.audio_embed_size,
                cfg.diffusion_head_num_layers,
            )
        elif cfg.audio_input_type == "continuous":
            self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
            self.prediction_head = ContinuousHead(
                cfg.decoder_cfg.dim,
                cfg.audio_embed_size,
            )
        self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg)
        self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg)
        self.stop_token_embd = None
        self.stop_prediction_head = None
        if cfg.stop_prediction_head:
            if cfg.stop_token_embd:
                self.stop_token_embd = nn.Embedding(2, cfg.dim, padding_idx=0)
            self.stop_prediction_head = StopPredictionHead(cfg.dim)
        # Learnable attention-sink tokens prepended to the text memory.
        if cfg.num_sink_tokens > 0:
            self.sink_tokens = nn.Parameter(
                torch.randn(cfg.num_sink_tokens, cfg.dim) * 0.02, requires_grad=True
            )
        else:
            self.sink_tokens = None
        self.disabled_crossatt_head_idx = cfg.disabled_crossatt_head_idx

    @property
    def num_sink_tokens(self):
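        """Number of learnable sink tokens (0 when disabled).

        Exposed as a property because call sites (e.g. the disabled
        ``generate_with_playhead`` draft below) read it as an attribute.
        """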
        if self.sink_tokens is None:
            return 0
        else:
            n_sink, _ = self.sink_tokens.shape
            return n_sink

    @classmethod
    def instantiate_from_config(cls, config):
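        """Build a model from a plain config dict.

        Nested encoder/decoder configs are materialised through their
        registries before being wrapped in ``TTSConfig``. Returns both the
        model and the parsed config.
        """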
        for k in config.keys():
            if k == "decoder_cfg":
                config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k])
            if k == "encoder_cfg":
                config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k])
        config = TTSConfig(**config)
        return ARTTSModel(config), config

    @classmethod
    def from_pretrained_local(
        cls,
        path: str,
        config_filename: str = "config.json",
        model_filename: str = "model.st",
        device: str = "cpu",
    ):
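        """Load a config and a safetensors checkpoint from a local directory."""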
        with open(Path(path) / config_filename, "r") as f:
            config = json.load(f)
        model, config = cls.instantiate_from_config(config)
        state_dict = load_file(Path(path) / model_filename, device=device)
        model.load_state_dict(state_dict)
        return model

    def _get_query(self, x: torch.Tensor, *args):
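        """Embed raw audio inputs and delegate query construction to the decoder."""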
        input_audio_embd = self.audio_embd(x)
        return self.audio_decoder._get_query(input_audio_embd, *args)

    def forward(
        self,
        text_ids: torch.LongTensor,
        audio_inputs: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        audio_mask: torch.Tensor | None = None,
        stop_tokens: torch.Tensor | None = None,
        text_stop_tokens: torch.Tensor | None = None,
        text_rel_pos: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
        crossatt_rel_pos: torch.Tensor | None = None,
        n_first_layers: int | None = None,
        cache: Any | None = None,
    ):
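        """Teacher-forced forward pass.

        Encodes ``text_ids``, optionally prepends sink tokens to the text
        memory (padding ``crossatt_mask`` accordingly), and runs the audio
        decoder on ``audio_inputs`` shifted right by one frame. Returns the
        decoder's pre-head hidden states.
        """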
        input_text_embd = self.text_embd(text_ids)
        # Drop the last frame so the decoder predicts one step ahead.
        input_audio_embd = self.audio_embd(audio_inputs[:, :-1])
        if self.stop_token_embd is not None:
            if stop_tokens is not None:
                stop_tokens_embd = self.stop_token_embd(stop_tokens)
                input_audio_embd += stop_tokens_embd[:, :-1]
        text_hidden_states = self.text_encoder(
            input_text_embd,
            mask=text_mask,
            text_rel_pos=text_rel_pos,
        )
        if self.disabled_crossatt_head_idx is not None and crossatt_mask is not None:
            # Build one mask per decoder layer so individual cross-attention
            # heads can be switched off; sink positions stay visible.
            crossatt_mask_list = []
            n_sink = self.num_sink_tokens
            for layer in self.audio_decoder.decoder_layers:
                if layer.crossatt is not None:
                    h = layer.crossatt.heads
                    crossatt_layer_mask = (
                        crossatt_mask.unsqueeze(1).repeat(1, h, 1, 1).clone()
                    )
                    crossatt_layer_mask = torch.nn.functional.pad(
                        crossatt_layer_mask,
                        (n_sink, 0),
                        value=True,
                    )
                    crossatt_mask_list.append(crossatt_layer_mask[:, :, :-1])
                else:
                    crossatt_mask_list.append(None)
            for layer, head in self.disabled_crossatt_head_idx:
                crossatt_mask_list[layer][:, head, :, n_sink:] = False
            crossatt_mask = crossatt_mask_list
        else:
            if self.sink_tokens is not None:
                n_sink, _ = self.sink_tokens.shape
                if crossatt_mask is not None:
                    crossatt_mask = torch.nn.functional.pad(
                        crossatt_mask,
                        (n_sink, 0),
                        value=True,
                    )
                    crossatt_mask = crossatt_mask[:, :-1]
        if self.sink_tokens is not None:
            sink_tokens = self.sink_tokens[None, :].repeat(
                text_hidden_states.shape[0], 1, 1
            )
            text_hidden_states = torch.cat(
                (sink_tokens, text_hidden_states),
                dim=1,
            )
        if n_first_layers is not None:
            pre_logits = self.audio_decoder.forward_first_n_layers(
                text_hidden_states,
                input_audio_embd,
                n_first_layers,
                crossatt_mask=crossatt_mask,
            )
        else:
            pre_logits = self.audio_decoder(
                text_hidden_states,
                input_audio_embd,
                crossatt_mask=crossatt_mask,
                cache=cache,
            )
        return pre_logits

    def generate(
        self,
        text_ids: torch.LongTensor,
        prefix: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
        text_rel_pos: torch.LongTensor | None = None,
        teacher_force: torch.Tensor | None = None,
        unfold_ref: bool = False,
        max_seq_len: int = 200,
        device: str = "cuda",
        sampling_params: dict | None = None,
        stop_threshold: float = 0.5,
        cache: Any | None = None,
        do_not_stop: bool = False,
    ):
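        """Autoregressive sampling loop.

        Prefills the decoder cache with ``prefix``, then decodes one frame at
        a time until every batch item has fired the stop head (unless
        ``do_not_stop``). Returns the decoder cache and a list of per-item
        predictions, each trimmed at its stop index.
        """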
        if sampling_params is None:
            sampling_params = {}
        if text_ids.ndim == 1:
            text_ids = text_ids.unsqueeze(0)
        batch_size = text_ids.shape[0]
        input_text_embd = self.text_embd(text_ids)
        text_hidden_states = self.text_encoder(
            input_text_embd,
            text_rel_pos=text_rel_pos,
            mask=text_mask,
        )
        prefix_embd = self.audio_embd(prefix)
        if self.sink_tokens is not None:
            sink_tokens = self.sink_tokens[None, :].repeat(
                text_hidden_states.shape[0], 1, 1
            )
            text_hidden_states = torch.cat(
                (sink_tokens, text_hidden_states),
                dim=1,
            )
            if crossatt_mask is not None:
                n_sink, _ = self.sink_tokens.shape
                crossatt_mask = torch.nn.functional.pad(
                    crossatt_mask,
                    (n_sink, 0),
                    value=True,
                )
        if cache is None:
            cache = self.audio_decoder.init_cache(
                max_seq_len + prefix_embd.shape[1], device
            )
        stop_status = torch.zeros(batch_size, device=device).bool()
        # Per-item index of the first stop signal; defaults to max_seq_len.
        stop_idx = torch.ones(batch_size, device=device).long() * max_seq_len
        preds = []
        pre_prediction = self.audio_decoder.prefill(
            text_hidden_states,
            prefix_embd,
            cache=cache,
        )
        prediction = self.prediction_head.predict(
            pre_prediction[:, [-1]], **sampling_params
        )
        prediction_embd = self.audio_embd(prediction)
        for i in tqdm(range(max_seq_len)):
            pre_prediction = self.audio_decoder.decode_one(
                text_hidden_states,
                prediction_embd,
                cache,
                crossatt_mask=crossatt_mask,
            )
            if unfold_ref:
                # The batch is doubled with a reference stream (e.g. for guidance).
                pre_prediction, pre_prediction_ref = pre_prediction.chunk(2)
            else:
                pre_prediction_ref = None
            prediction = self.prediction_head.predict(
                pre_prediction,
                pre_prediction_ref=pre_prediction_ref,
                **sampling_params,
            )
            prediction_embd = self.audio_embd(prediction)
            if unfold_ref:
                prediction_embd = prediction_embd.repeat(2, 1, 1)
            if teacher_force is not None:
                b, n, d = teacher_force.shape
                if i < n:
                    prediction_embd = self.audio_embd(teacher_force[:, [i]])
            preds.append(prediction)
            if self.stop_prediction_head is not None:
                stop_pred = self.stop_prediction_head(pre_prediction).squeeze((1, 2))
                stop_signal = stop_pred > stop_threshold
                stop_status |= stop_signal
                # Record the first step at which each item signalled stop.
                stop_idx[stop_signal & (stop_idx > i)] = i
                if stop_status.all():
                    if self.stop_token_embd is not None:
                        st_embd = self.stop_token_embd(
                            torch.ones(1, 1, device=device).int()
                        )
                        prediction_embd += st_embd
                    if not do_not_stop:
                        break
                    else:
                        print(f"STOP: {i}")
        full_prediction = torch.cat(preds, dim=1)
        # Trim each sequence at its recorded stop index.
        full_prediction = [
            x[: stop_idx[i]][None] for i, x in enumerate(full_prediction.unbind())
        ]
        return cache, full_prediction
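
    # NOTE: the string below is a disabled draft kept for reference; it
    # depends on a PlayHead model and on self.text_stop_token_embd, neither
    # of which is defined in this file.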
| """ | |
| def generate_with_playhead( | |
| self, | |
| text_ids: torch.LongTensor, | |
| prefix: torch.Tensor, | |
| playhead_model: PlayHead, | |
| selected_heads_idx: list[tuple[int, int]], | |
| text_stop_tokens: torch.LongTensor | None = None, | |
| text_mask: torch.Tensor | None = None, | |
| text_rel_pos: torch.LongTensor | None = None, | |
| teacher_force: torch.Tensor | None = None, | |
| max_seq_len: int = 200, | |
| device: str = "cuda", | |
| sampling_params: dict | None = None, | |
| stop_threshold: float = 0.5, | |
| do_not_stop: bool = False, | |
| width: tuple[int, int] = (5, 1), | |
| abs_pos_start: int = 0, | |
| stop_end_distance_threshold: int = 5, | |
| ): | |
| if sampling_params is None: | |
| sampling_params = {} | |
| if text_ids.ndim == 1: | |
| text_ids = text_ids.unsqueeze(0) | |
| input_text_embd = self.text_embd(text_ids) | |
| if self.text_stop_token_embd is not None: | |
| if text_stop_tokens is not None: | |
| text_stop_tokens_embd = self.text_stop_token_embd(text_stop_tokens) | |
| input_text_embd += text_stop_tokens_embd | |
| text_hidden_states = self.text_encoder( | |
| input_text_embd, | |
| text_rel_pos=text_rel_pos, | |
| mask=text_mask, | |
| ) | |
| prefix_embd = self.audio_embd(prefix) | |
| text_len = text_hidden_states.shape[1] | |
| if self.sink_tokens is not None: | |
| sink_tokens = self.sink_tokens[None, :].repeat( | |
| text_hidden_states.shape[0], 1, 1 | |
| ) | |
| text_hidden_states = torch.cat( | |
| (sink_tokens, text_hidden_states), | |
| dim=1, | |
| ) | |
| cache = self.audio_decoder.init_cache(max_seq_len, device) | |
| preds = [] | |
| pre_prediction = self.audio_decoder.prefill( | |
| text_hidden_states, | |
| prefix_embd, | |
| cache=cache, | |
| ) | |
| text_freqs = None | |
| prediction = self.prediction_head.predict( | |
| pre_prediction[:, [-1]], **sampling_params | |
| ) | |
| prediction_embd = self.audio_embd(prediction) | |
| preds.append(prediction) | |
| playhead_cache = playhead_model.init_cache() | |
| previous_position = torch.zeros(1, device=device) | |
| abs_pos = torch.ones(1, 1, device=device).long() * abs_pos_start | |
| selected_heads_frame = collect_heads(cache, selected_heads_idx, last=False) | |
| selected_heads_frame = selected_heads_frame.sum(1).transpose(-1, -2) | |
| pos_preds = [] | |
| steps = [] | |
| expand_crossatt_mask = [] | |
| for i in tqdm(range(selected_heads_frame.shape[2])): | |
| pred, step = playhead_model.predict( | |
| selected_heads_frame[..., [i]], | |
| cache=playhead_cache, | |
| previous_position=previous_position, | |
| ) | |
| previous_position = pred | |
| abs_pos += step | |
| pos_preds.append(pred) | |
| steps.append(step) | |
| exp_ca_mask = mask_from_abs_pos( | |
| abs_pos, | |
| (text_len // playhead_model.avg_pool_stride) + 1, | |
| playhead_model.avg_pool_stride, | |
| width=width, | |
| ) | |
| exp_ca_mask = torch.nn.functional.pad( | |
| exp_ca_mask, (self.num_sink_tokens, 0), value=True | |
| ).bool()[..., : text_len + self.num_sink_tokens] | |
| expand_crossatt_mask.append(exp_ca_mask) | |
| print("starting at: ", abs_pos.item()) | |
| # pos_pred, step = playhead_model.predict( | |
| # selected_heads_frame, | |
| # cache=playhead_cache, | |
| # previous_position=previous_position, | |
| # ) | |
| # previous_position = pos_pred[:, [-1]] | |
| # abs_pos += step | |
| # exp_ca_mask = mask_from_abs_pos( | |
| # abs_pos, | |
| # (text_len // playhead_model.avg_pool_stride) + 1, | |
| # playhead_model.avg_pool_stride, | |
| # width=width, | |
| # ) | |
| # expand_crossatt_mask.append(exp_ca_mask) | |
| # steps.append(step) | |
| # pos_preds.append(pos_pred) | |
| progress_bar = tqdm(range(max_seq_len)) | |
| for i in progress_bar: | |
| pre_prediction = self.audio_decoder.decode_one( | |
| text_hidden_states, | |
| prediction_embd, | |
| cache, | |
| # text_freqs=text_freqs, | |
| crossatt_mask=exp_ca_mask, | |
| ) | |
| prediction = self.prediction_head.predict(pre_prediction, **sampling_params) | |
| prediction_embd = self.audio_embd(prediction) | |
| if teacher_force is not None: | |
| b, n, d = teacher_force.shape | |
| if i < n: | |
| prediction_embd = self.audio_embd(teacher_force[:, [i]]) | |
| ### PLAYHEAD ======================== | |
| selected_heads_frame = ( | |
| collect_heads(cache, selected_heads_idx).sum(1).transpose(-1, -2) | |
| ) | |
| pos_pred, step = playhead_model.predict( | |
| selected_heads_frame, | |
| cache=playhead_cache, | |
| previous_position=previous_position, | |
| ) | |
| previous_position = pos_pred | |
| abs_pos += step | |
| exp_ca_mask = mask_from_abs_pos( | |
| abs_pos, | |
| (text_len // playhead_model.avg_pool_stride) + 1, | |
| playhead_model.avg_pool_stride, | |
| width=width, | |
| ) | |
| exp_ca_mask = torch.nn.functional.pad( | |
| exp_ca_mask, (self.num_sink_tokens, 0), value=True | |
| ).bool()[..., : text_len + self.num_sink_tokens] | |
| expand_crossatt_mask.append(exp_ca_mask) | |
| steps.append(step) | |
| pos_preds.append(pos_pred) | |
| # ================================= | |
| preds.append(prediction) | |
| if self.stop_prediction_head is not None: | |
| stop_pred = self.stop_prediction_head(pre_prediction) | |
| if stop_pred > stop_threshold: | |
| dist = np.abs( | |
| abs_pos.cpu().item() * playhead_model.avg_pool_stride - text_len | |
| ) | |
| progress_bar.set_postfix( | |
| {"stop": f"pos: {abs_pos.cpu().item()}; dist{dist}"} | |
| ) | |
| if dist < stop_end_distance_threshold and not do_not_stop: | |
| break | |
| progress_bar.set_postfix({"position": abs_pos.cpu().item()}) | |
| full_prediction = torch.cat(preds, dim=1) | |
| expand_crossatt_mask = torch.cat(expand_crossatt_mask, dim=1) | |
| print(expand_crossatt_mask.shape) | |
| return cache, full_prediction, expand_crossatt_mask, steps, pos_preds | |
| """ | |