import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
from tqdm import tqdm

from tts.model.config import TTSConfig
from tts.model.prediction_head import (
    ContinuousHead,
    LogitsHead,
    StopPredictionHead,
    VelocityHead,
)
from tts.model.registry import DECODER_REGISTRY, ENCODER_REGISTRY
from tts.tools import path_matrix, widen_alignment


def collect_heads(cache, selected_heads, last=True):
    """Gather cross-attention weights for the (layer, head) pairs in
    `selected_heads` from the decoder cache; with `last=True` only the most
    recent query position is kept (with a trailing singleton time axis)."""
    if last:
        return torch.cat(
            [
                cache[layer]["crossatt_weights"][:, [head], -1]
                for layer, head in selected_heads
            ],
            dim=1,
        )[:, :, None]
    return torch.cat(
        [
            cache[layer]["crossatt_weights"][:, [head]]
            for layer, head in selected_heads
        ],
        dim=1,
    )


def mask_from_abs_pos(abs_pos, text_len, expand_factor, width=(5, 1)):
    """Build a cross-attention mask from an absolute playhead position:
    a monotonic path matrix, widened by `width` along the source axis,
    then expanded back to full text resolution."""
    exp_ca_mask = path_matrix(abs_pos, text_len)
    exp_ca_mask = widen_alignment(exp_ca_mask, width=width, axis="S")
    exp_ca_mask = expand(exp_ca_mask, expand_factor)
    return exp_ca_mask


def expand(x, r):
    """Repeat each entry along the last axis `r` times: (b, n, d) -> (b, n, r * d)."""
    b, n, d = x.shape
    return x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)

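# A minimal sketch of what `expand` does (the values here are illustrative,
# not from the original source). It repeats each column of a pooled alignment
# mask `r` times, so a mask built at a coarser stride by
# `path_matrix`/`widen_alignment` can be stretched back to full text length:
#
#   >>> m = torch.tensor([[[True, False]]])  # (batch=1, frames=1, text=2)
#   >>> expand(m, 3)
#   tensor([[[ True,  True,  True, False, False, False]]])
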
class ARTTSModel(nn.Module):
    def __init__(self, cfg: TTSConfig):
        super().__init__()
        self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim)
        if cfg.audio_input_type == "discrete":
            self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim)
            self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size)
        elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion:
            self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
            self.prediction_head = VelocityHead(
                cfg.decoder_cfg.dim,
                cfg.audio_embed_size,
                cfg.diffusion_head_num_layers,
            )
        elif cfg.audio_input_type == "continuous":
            self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
            self.prediction_head = ContinuousHead(
                cfg.decoder_cfg.dim,
                cfg.audio_embed_size,
            )
        self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg)
        self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg)
        self.stop_token_embd = None
        self.stop_prediction_head = None
        if cfg.stop_prediction_head:
            if cfg.stop_token_embd:
                self.stop_token_embd = nn.Embedding(2, cfg.dim, padding_idx=0)
            self.stop_prediction_head = StopPredictionHead(cfg.dim)
        if cfg.num_sink_tokens > 0:
            self.sink_tokens = nn.Parameter(
                torch.randn(cfg.num_sink_tokens, cfg.dim) * 0.02, requires_grad=True
            )
        else:
            self.sink_tokens = None
        self.disabled_crossatt_head_idx = cfg.disabled_crossatt_head_idx

    @property
    def num_sink_tokens(self):
        if self.sink_tokens is None:
            return 0
        n_sink, _ = self.sink_tokens.shape
        return n_sink

    @classmethod
    def instantiate_from_config(cls, config):
        for k in config.keys():
            if k == "decoder_cfg":
                config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k])
            if k == "encoder_cfg":
                config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k])
        config = TTSConfig(**config)
        return ARTTSModel(config), config

    @classmethod
    def from_pretrained_local(
        cls,
        path: str,
        config_filename: str = "config.json",
        model_filename: str = "model.st",
        device: str = "cpu",
    ):
        with open(Path(path) / config_filename, "r") as f:
            config = json.load(f)
        model, config = cls.instantiate_from_config(config)
        state_dict = load_file(Path(path) / model_filename, device=device)
        model.load_state_dict(state_dict)
        return model

    def _get_query(self, x: torch.Tensor, *args):
        input_audio_embd = self.audio_embd(x)
        return self.audio_decoder._get_query(input_audio_embd, *args)

    def forward(
        self,
        text_ids: torch.LongTensor,
        audio_inputs: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        audio_mask: torch.Tensor | None = None,
        stop_tokens: torch.Tensor | None = None,
        text_stop_tokens: torch.Tensor | None = None,
        text_rel_pos: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
        crossatt_rel_pos: torch.Tensor | None = None,
        n_first_layers: int | None = None,
        cache: Any | None = None,
    ):
        input_text_embd = self.text_embd(text_ids)
        # Shift right: the decoder sees frames [0, T-1] and predicts [1, T].
        input_audio_embd = self.audio_embd(audio_inputs[:, :-1])
        if self.stop_token_embd is not None and stop_tokens is not None:
            stop_tokens_embd = self.stop_token_embd(stop_tokens)
            input_audio_embd += stop_tokens_embd[:, :-1]
        text_hidden_states = self.text_encoder(
            input_text_embd,
            mask=text_mask,
            text_rel_pos=text_rel_pos,
        )
        if self.disabled_crossatt_head_idx is not None and crossatt_mask is not None:
            # Build one mask per decoder layer so individual cross-attention
            # heads can be switched off.
            crossatt_mask_list = []
            # Use the property here: `self.sink_tokens.shape` crashes when
            # sink tokens are disabled.
            n_sink = self.num_sink_tokens
            for layer in self.audio_decoder.decoder_layers:
                if layer.crossatt is not None:
                    h = layer.crossatt.heads
                    crossatt_layer_mask = (
                        crossatt_mask.unsqueeze(1).repeat(1, h, 1, 1).clone()
                    )
                    crossatt_layer_mask = torch.nn.functional.pad(
                        crossatt_layer_mask,
                        (n_sink, 0),
                        value=True,
                    )
                    crossatt_mask_list.append(crossatt_layer_mask[:, :, :-1])
                else:
                    crossatt_mask_list.append(None)
            for layer, head in self.disabled_crossatt_head_idx:
                # Disabled heads may still attend to the sink tokens.
                crossatt_mask_list[layer][:, head, :, n_sink:] = False
            crossatt_mask = crossatt_mask_list
        else:
            if self.sink_tokens is not None and crossatt_mask is not None:
                n_sink, _ = self.sink_tokens.shape
                crossatt_mask = torch.nn.functional.pad(
                    crossatt_mask,
                    (n_sink, 0),
                    value=True,
                )
            if crossatt_mask is not None:
                # Drop the last query position to match the shifted audio input.
                crossatt_mask = crossatt_mask[:, :-1]
        if self.sink_tokens is not None:
            sink_tokens = self.sink_tokens[None, :].repeat(
                text_hidden_states.shape[0], 1, 1
            )
            text_hidden_states = torch.cat(
                (sink_tokens, text_hidden_states),
                dim=1,
            )
        if n_first_layers is not None:
            pre_logits = self.audio_decoder.forward_first_n_layers(
                text_hidden_states,
                input_audio_embd,
                n_first_layers,
                crossatt_mask=crossatt_mask,
            )
        else:
            pre_logits = self.audio_decoder(
                text_hidden_states,
                input_audio_embd,
                crossatt_mask=crossatt_mask,
                cache=cache,
            )
        return pre_logits
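    # Inference in `generate` below follows the usual incremental pattern:
    # the text is encoded once, `prefill` consumes the audio prefix to
    # populate the decoder KV cache, and each new frame is produced by
    # `decode_one` against that cache, so earlier frames are never
    # re-processed.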
    def generate(
        self,
        text_ids: torch.LongTensor,
        prefix: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
        text_rel_pos: torch.LongTensor | None = None,
        teacher_force: torch.Tensor | None = None,
        unfold_ref: bool = False,
        max_seq_len: int = 200,
        device: str = "cuda",
        sampling_params: dict | None = None,
        stop_threshold: float = 0.5,
        cache: Any | None = None,
        do_not_stop: bool = False,
    ):
        if sampling_params is None:
            sampling_params = {}
        if text_ids.ndim == 1:
            text_ids = text_ids.unsqueeze(0)
        batch_size = text_ids.shape[0]
        input_text_embd = self.text_embd(text_ids)
        text_hidden_states = self.text_encoder(
            input_text_embd,
            text_rel_pos=text_rel_pos,
            mask=text_mask,
        )
        prefix_embd = self.audio_embd(prefix)
        if self.sink_tokens is not None:
            sink_tokens = self.sink_tokens[None, :].repeat(
                text_hidden_states.shape[0], 1, 1
            )
            text_hidden_states = torch.cat(
                (sink_tokens, text_hidden_states),
                dim=1,
            )
            if crossatt_mask is not None:
                n_sink, _ = self.sink_tokens.shape
                crossatt_mask = torch.nn.functional.pad(
                    crossatt_mask,
                    (n_sink, 0),
                    value=True,
                )
        if cache is None:
            cache = self.audio_decoder.init_cache(
                max_seq_len + prefix_embd.shape[1], device
            )
        stop_status = torch.zeros(batch_size, device=device).bool()
        stop_idx = torch.ones(batch_size, device=device).long() * max_seq_len
        preds = []
        pre_prediction = self.audio_decoder.prefill(
            text_hidden_states,
            prefix_embd,
            cache=cache,
        )
        prediction = self.prediction_head.predict(
            pre_prediction[:, [-1]], **sampling_params
        )
        prediction_embd = self.audio_embd(prediction)
        for i in tqdm(range(max_seq_len)):
            pre_prediction = self.audio_decoder.decode_one(
                text_hidden_states,
                prediction_embd,
                cache,
                crossatt_mask=crossatt_mask,
            )
            if unfold_ref:
                # The batch carries a reference copy in its second half.
                pre_prediction, pre_prediction_ref = pre_prediction.chunk(2)
            else:
                pre_prediction_ref = None
            prediction = self.prediction_head.predict(
                pre_prediction,
                pre_prediction_ref=pre_prediction_ref,
                **sampling_params,
            )
            prediction_embd = self.audio_embd(prediction)
            if unfold_ref:
                prediction_embd = prediction_embd.repeat(2, 1, 1)
            if teacher_force is not None:
                b, n, d = teacher_force.shape
                if i < n:
                    prediction_embd = self.audio_embd(teacher_force[:, [i]])
            preds.append(prediction)
            if self.stop_prediction_head is not None:
                # `.squeeze(1, 2)` is not a valid call; `dim` takes a tuple.
                stop_pred = self.stop_prediction_head(pre_prediction).squeeze((1, 2))
                stop_signal = stop_pred > stop_threshold
                stop_status |= stop_signal
                # Record the first step at which each sequence stopped.
                stop_idx[stop_signal & (stop_idx > i)] = i
                if stop_status.all():
                    if self.stop_token_embd is not None:
                        st_embd = self.stop_token_embd(
                            torch.ones(1, 1, device=device).int()
                        )
                        prediction_embd += st_embd
                    if not do_not_stop:
                        break
                    else:
                        print(f"STOP: {i}")
        full_prediction = torch.cat(preds, dim=1)
        # Trim each sequence at its own stop index.
        full_prediction = [
            x[: stop_idx[b]][None] for b, x in enumerate(full_prediction.unbind())
        ]
        return cache, full_prediction
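    # Stop handling in `generate`: each sequence keeps emitting until its
    # stop-head probability crosses `stop_threshold`; `stop_idx` remembers the
    # first such step per sequence, so the batched output can be trimmed to
    # per-sequence lengths at the end.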
    # The playhead-guided variant below is kept disabled: it depends on a
    # `PlayHead` model and a `text_stop_token_embd` module that are not
    # defined in this file.
    """
    def generate_with_playhead(
        self,
        text_ids: torch.LongTensor,
        prefix: torch.Tensor,
        playhead_model: PlayHead,
        selected_heads_idx: list[tuple[int, int]],
        text_stop_tokens: torch.LongTensor | None = None,
        text_mask: torch.Tensor | None = None,
        text_rel_pos: torch.LongTensor | None = None,
        teacher_force: torch.Tensor | None = None,
        max_seq_len: int = 200,
        device: str = "cuda",
        sampling_params: dict | None = None,
        stop_threshold: float = 0.5,
        do_not_stop: bool = False,
        width: tuple[int, int] = (5, 1),
        abs_pos_start: int = 0,
        stop_end_distance_threshold: int = 5,
    ):
        if sampling_params is None:
            sampling_params = {}
        if text_ids.ndim == 1:
            text_ids = text_ids.unsqueeze(0)
        input_text_embd = self.text_embd(text_ids)
        if self.text_stop_token_embd is not None:
            if text_stop_tokens is not None:
                text_stop_tokens_embd = self.text_stop_token_embd(text_stop_tokens)
                input_text_embd += text_stop_tokens_embd
        text_hidden_states = self.text_encoder(
            input_text_embd,
            text_rel_pos=text_rel_pos,
            mask=text_mask,
        )
        prefix_embd = self.audio_embd(prefix)
        text_len = text_hidden_states.shape[1]
        if self.sink_tokens is not None:
            sink_tokens = self.sink_tokens[None, :].repeat(
                text_hidden_states.shape[0], 1, 1
            )
            text_hidden_states = torch.cat(
                (sink_tokens, text_hidden_states),
                dim=1,
            )
        cache = self.audio_decoder.init_cache(max_seq_len, device)
        preds = []
        pre_prediction = self.audio_decoder.prefill(
            text_hidden_states,
            prefix_embd,
            cache=cache,
        )
        text_freqs = None
        prediction = self.prediction_head.predict(
            pre_prediction[:, [-1]], **sampling_params
        )
        prediction_embd = self.audio_embd(prediction)
        preds.append(prediction)
        playhead_cache = playhead_model.init_cache()
        previous_position = torch.zeros(1, device=device)
        abs_pos = torch.ones(1, 1, device=device).long() * abs_pos_start
        selected_heads_frame = collect_heads(cache, selected_heads_idx, last=False)
        selected_heads_frame = selected_heads_frame.sum(1).transpose(-1, -2)
        pos_preds = []
        steps = []
        expand_crossatt_mask = []
        # Warm up the playhead on the prefix frames, one step at a time.
        for i in tqdm(range(selected_heads_frame.shape[2])):
            pred, step = playhead_model.predict(
                selected_heads_frame[..., [i]],
                cache=playhead_cache,
                previous_position=previous_position,
            )
            previous_position = pred
            abs_pos += step
            pos_preds.append(pred)
            steps.append(step)
            exp_ca_mask = mask_from_abs_pos(
                abs_pos,
                (text_len // playhead_model.avg_pool_stride) + 1,
                playhead_model.avg_pool_stride,
                width=width,
            )
            exp_ca_mask = torch.nn.functional.pad(
                exp_ca_mask, (self.num_sink_tokens, 0), value=True
            ).bool()[..., : text_len + self.num_sink_tokens]
            expand_crossatt_mask.append(exp_ca_mask)
        print("starting at: ", abs_pos.item())
        # pos_pred, step = playhead_model.predict(
        #     selected_heads_frame,
        #     cache=playhead_cache,
        #     previous_position=previous_position,
        # )
        # previous_position = pos_pred[:, [-1]]
        # abs_pos += step
        # exp_ca_mask = mask_from_abs_pos(
        #     abs_pos,
        #     (text_len // playhead_model.avg_pool_stride) + 1,
        #     playhead_model.avg_pool_stride,
        #     width=width,
        # )
        # expand_crossatt_mask.append(exp_ca_mask)
        # steps.append(step)
        # pos_preds.append(pos_pred)
        progress_bar = tqdm(range(max_seq_len))
        for i in progress_bar:
            pre_prediction = self.audio_decoder.decode_one(
                text_hidden_states,
                prediction_embd,
                cache,
                # text_freqs=text_freqs,
                crossatt_mask=exp_ca_mask,
            )
            prediction = self.prediction_head.predict(pre_prediction, **sampling_params)
            prediction_embd = self.audio_embd(prediction)
            if teacher_force is not None:
                b, n, d = teacher_force.shape
                if i < n:
                    prediction_embd = self.audio_embd(teacher_force[:, [i]])
            ### PLAYHEAD ========================
            selected_heads_frame = (
                collect_heads(cache, selected_heads_idx).sum(1).transpose(-1, -2)
            )
            pos_pred, step = playhead_model.predict(
                selected_heads_frame,
                cache=playhead_cache,
                previous_position=previous_position,
            )
            previous_position = pos_pred
            abs_pos += step
            exp_ca_mask = mask_from_abs_pos(
                abs_pos,
                (text_len // playhead_model.avg_pool_stride) + 1,
                playhead_model.avg_pool_stride,
                width=width,
            )
            exp_ca_mask = torch.nn.functional.pad(
                exp_ca_mask, (self.num_sink_tokens, 0), value=True
            ).bool()[..., : text_len + self.num_sink_tokens]
            expand_crossatt_mask.append(exp_ca_mask)
            steps.append(step)
            pos_preds.append(pos_pred)
            # =================================
            preds.append(prediction)
            if self.stop_prediction_head is not None:
                stop_pred = self.stop_prediction_head(pre_prediction)
                if stop_pred > stop_threshold:
                    dist = np.abs(
                        abs_pos.cpu().item() * playhead_model.avg_pool_stride - text_len
                    )
                    progress_bar.set_postfix(
                        {"stop": f"pos: {abs_pos.cpu().item()}; dist: {dist}"}
                    )
                    if dist < stop_end_distance_threshold and not do_not_stop:
                        break
            progress_bar.set_postfix({"position": abs_pos.cpu().item()})
        full_prediction = torch.cat(preds, dim=1)
        expand_crossatt_mask = torch.cat(expand_crossatt_mask, dim=1)
        print(expand_crossatt_mask.shape)
        return cache, full_prediction, expand_crossatt_mask, steps, pos_preds
    """
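
# A minimal usage sketch (not part of the library): the checkpoint directory,
# vocabulary size, and tensor shapes below are placeholder assumptions for
# illustration only. `prefix` must live in the model's audio input space:
# embedding frames for the "continuous" input type, token ids for "discrete".
if __name__ == "__main__":
    model = ARTTSModel.from_pretrained_local("checkpoints/artts").eval()
    text_ids = torch.randint(0, 100, (1, 32))  # stand-in for tokenized text
    prefix = torch.randn(1, 8, 512)  # assumed continuous audio prompt (B, T, D)
    with torch.no_grad():
        cache, frames = model.generate(
            text_ids,
            prefix,
            max_seq_len=50,
            device="cpu",
            sampling_params={},  # forwarded to self.prediction_head.predict
        )
    print([f.shape for f in frames])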