# pardi-speech/tts/playhead.py
import torch
from torch import nn

from model.cache_utils import FLACache
from model.config import PlayHeadConfig
from model.prediction_head import CircularHead, LogitsHead
from model.simple_gla import SimpleGLABlock


class PlayHead(nn.Module):
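    """Track a cyclic "playhead" position from cross-attention weights.

    The attention distribution over the attended tokens (a few sink tokens plus
    cyclically indexed positions for the rest) is projected onto position
    embeddings, processed by a stack of gated linear attention (GLA) blocks, and
    decoded by an optional categorical logits head and/or an optional circular
    (angle-based) head.
    """
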
def __init__(self, cfg: PlayHeadConfig):
super().__init__()
        self.cycle_len = cfg.cycle_len
        self.num_sink_tokens = cfg.num_sink_tokens
        # One embedding slot per cyclic position plus one per sink token.
        self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim)
        self.avg_pool_stride = cfg.avg_pool_stride
self.net = nn.ModuleList(
[
SimpleGLABlock(
dim=cfg.dim,
num_heads=cfg.dim // 128,
layer_idx=i,
expand_k=0.5,
expand_v=1.0,
use_short_conv=True,
ffn_expansion_factor=4,
)
for i in range(cfg.num_layers)
]
)
        # Optional prediction heads: a categorical head over the cycle_len
        # positions and/or a circular (angle-based) head.
        self.logits_head = (
            LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None
        )
        self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None

def forward(
self,
cross_attention_weights: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor | None = None,
):
        B, T, A = cross_attention_weights.shape
        # if self.cross_attention_reduction == "sum":
        #     cross_attention_weights = cross_attention_weights.sum(1)
        device = cross_attention_weights.device
        # Cyclic position ids for the attended tokens, plus dedicated ids for
        # the sink tokens (sink ids live past the cycle range, i.e. in
        # [cycle_len, cycle_len + num_sink_tokens)).
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        # Soft-assign a position embedding to each of the A query steps by
        # weighting the (sink + cyclic) position embeddings with the attention
        # distribution: (B, A, T) @ (1, T, dim) -> (B, A, dim).
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x)
        losses = dict()
        if self.logits_head is not None:
            # Discrete position loss over the cycle_len classes (targets cast to long).
            losses |= self.logits_head.compute_loss(x, target.long(), mask=mask)
        if self.circular_head is not None:
            # Loss on the circular (angle-based) parameterisation of the position.
            losses |= self.circular_head.compute_loss(x, target, mask=mask)
        return losses

    def init_cache(self):
        return FLACache(num_states=len(self.net))

    def predict(
self,
cross_attention_weights: torch.Tensor,
previous_position: torch.Tensor | None = None,
cache: FLACache | None = None,
):
        # Downsample the attended-token axis (excluding sink tokens) so that the
        # cyclic position advances at the pooled resolution.
        avg_pool_ca = torch.nn.functional.avg_pool1d(
            cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2),
            self.avg_pool_stride,
            stride=self.avg_pool_stride,
            ceil_mode=True,
        ).transpose(-1, -2)
        sink_ca = cross_attention_weights[:, : self.num_sink_tokens]
        cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1)
        B, T, A = cross_attention_weights.shape
        device = cross_attention_weights.device
        # Same cyclic + sink position embedding lookup as in forward().
        pos = torch.arange(T - self.num_sink_tokens, device=device) % self.cycle_len
        sink = torch.arange(self.num_sink_tokens, device=device) + self.cycle_len
        sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
        x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
        for block in self.net:
            x = block(x, cache=cache)
        if self.logits_head is not None:
            logits = self.logits_head(x)
            pred_position = torch.argmax(logits, -1)
            if previous_position is not None:
                # Map both positions onto the unit circle so the step between
                # them is measured modulo cycle_len (robust to wrap-around).
                current_angle, previous_angle = map(
                    lambda p: torch.exp(1j * 2 * torch.pi * p / self.cycle_len),
                    (pred_position, previous_position),
                )
                diff = current_angle / previous_angle
                # Signed step in (-cycle_len / 2, cycle_len / 2].
                step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long()
                return pred_position, step
            return pred_position
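

# Minimal usage sketch. This assumes PlayHeadConfig accepts the fields read
# above (dim, num_layers, cycle_len, num_sink_tokens, avg_pool_stride,
# logits_head, circular_head) as constructor keywords, and that targets are
# per-step position indices in [0, cycle_len); shapes and values are made up
# for illustration only.
if __name__ == "__main__":
    cfg = PlayHeadConfig(
        dim=256,
        num_layers=2,
        cycle_len=16,
        num_sink_tokens=4,
        avg_pool_stride=2,
        logits_head=True,
        circular_head=False,
    )
    playhead = PlayHead(cfg)

    B, T, A = 2, cfg.num_sink_tokens + 32, 10  # batch, attended tokens (sinks + text), query steps
    ca = torch.rand(B, T, A).softmax(dim=1)    # fake attention distribution over the attended tokens
    target = torch.randint(0, cfg.cycle_len, (B, A))

    # Training-style call: returns a dict of losses from the enabled heads.
    losses = playhead(ca, target)

    # Inference-style call: pools the attended-token axis and returns predicted
    # positions (plus a wrapped step when previous_position is given).
    cache = playhead.init_cache()
    pred = playhead.predict(ca, cache=cache)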