# pardi-speech/tts/tools.py
# Utilities from the initial demo of Lina-speech (pardi-speech).
from itertools import accumulate
from typing import Optional
import torch
import torch.nn.functional as F
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def widen_alignment(
alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S"
) -> torch.Tensor:
"""
Widen 1-bands along one axis of an alignment matrix.
Args:
alignment: (B, T, S) binary/bool/int tensor
width: int or (left, right) expansion
e.g. 2 -> expand ±2
(1,3) -> expand -1 on the left, +3 on the right
axis: "S" to widen horizontally (across S),
"T" to widen vertically (across T)
Returns:
(B, T, S) tensor with widened 1-bands along the chosen axis
"""
assert axis in ("S", "T")
orig_dtype = alignment.dtype
dev = alignment.device
# normalize widths
if isinstance(width, int):
left, right = width, width
else:
left, right = width
ksize = left + right + 1
kernel = torch.ones(1, 1, ksize, device=dev)
if axis == "S":
# (B*T, 1, S)
x = alignment.view(-1, 1, alignment.size(-1)).float()
x = F.pad(x, (left, right)) # explicit asymmetric padding
y = F.conv1d(x, kernel)
y = (y > 0).view_as(alignment)
else: # axis == "T"
# (B*S, 1, T)
x = (
alignment.permute(0, 2, 1)
.contiguous()
.view(-1, 1, alignment.size(1))
.float()
)
x = F.pad(x, (left, right))
y = F.conv1d(x, kernel)
# Back to (B, T, S)
y = (
(y > 0)
.view(alignment.size(0), alignment.size(2), alignment.size(1))
.permute(0, 2, 1)
)
    # Cast back to the original dtype (the > 0 comparison yields bool)
    if orig_dtype == torch.bool:
        return y
    return y.to(orig_dtype)
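
# Illustrative usage (shapes assumed, not from the repo):
#   a = torch.eye(4, dtype=torch.bool).unsqueeze(0)  # (1, 4, 4)
#   widen_alignment(a, 1, axis="S")       # 3-wide diagonal band along S
#   widen_alignment(a, (0, 2), axis="T")  # extend each band 2 steps forward in T
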
def collect_heads(cache, selected_heads):
    """Stack the last-step cross-attention weights of selected (layer, head) pairs.

    Assumes cache[layer]["crossatt_weights"] is (B, H, T, S); each pair
    contributes a (B, 1, S) slice, stacked to (B, len(selected_heads), 1, S).
    """
    return torch.stack(
        [
            cache[layer]["crossatt_weights"][:, [head], [-1]]
            for layer, head in selected_heads
        ],
        dim=1,
    )
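
# Illustrative usage, assuming the cache layout documented above:
#   att = collect_heads(cache, [(2, 3), (5, 0)])  # (B, 2, 1, S)
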
def expand(x, r):
    """Repeat each feature r times along the last dim: (b, n, d) -> (b, n, r * d)."""
    b, n, d = x.shape
    return x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
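
# Illustrative usage:
#   expand(torch.randn(1, 2, 3), 2).shape  # torch.Size([1, 2, 6])
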
def path_matrix(positions: torch.Tensor, num_positions: Optional[int] = None) -> torch.Tensor:
    """One-hot encode a position path along a new trailing axis of size num_positions."""
    if num_positions is None:
        num_positions = positions.max().item() + 1
    return F.one_hot(positions, num_classes=num_positions).to(torch.int)
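
# Illustrative usage: positions [0, 0, 1, 2] give a (4, 3) one-hot path.
#   path_matrix(torch.tensor([0, 0, 1, 2]))
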
def pad_2d_sequence(seq, padding_value=0):
    """Right/bottom-pad a list of 2D tensors to a common shape and stack them."""
    max_x, max_y = map(max, zip(*(x.shape for x in seq)))

    def pad(x):
        return F.pad(x, (0, max_y - x.shape[1], 0, max_x - x.shape[0]), value=padding_value)

    return torch.stack([pad(x) for x in seq])
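
# Illustrative usage: two ragged matrices stacked into one padded batch.
#   pad_2d_sequence([torch.ones(2, 3), torch.ones(4, 1)]).shape  # (2, 4, 3)
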
def audio_to_text_partial_neighbor_mask(
xlen,
ylen,
*,
past_tokens: int = 0,
future_tokens: int = 0,
device=None,
dtype=torch.bool,
):
"""
Build an (audio_len, text_len) boolean mask where True = allowed to attend.
Each audio frame (group g) can attend:
- all tokens of text group g (aligned word),
- last `past_tokens` tokens of text group g-1 (previous word),
- first `future_tokens` tokens of text group g+1 (next word).
Args:
xlen (list[int]): token counts per text word (groups), e.g. [2,1,3]
ylen (list[int]): frame counts per audio word (aligned groups), e.g. [4,2,5]
past_tokens (int): allow up to this many tokens from end of previous word
future_tokens (int): allow up to this many tokens from start of next word
device: torch device
dtype: output dtype (bool by default)
Returns:
mask: (A, T) boolean tensor (A = sum(ylen), T = sum(xlen))
"""
if len(xlen) != len(ylen):
raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}")
if any(l <= 0 for l in xlen) or any(l <= 0 for l in ylen):
raise ValueError("All lengths must be positive.")
if past_tokens < 0 or future_tokens < 0:
raise ValueError("past_tokens and future_tokens must be >= 0.")
n = len(xlen)
# Text-side: group id per token and position within its group
x_groups = torch.arange(n, device=device).repeat_interleave(
torch.tensor(xlen, device=device)
) # (T,)
pos_in_group = torch.cat([torch.arange(L, device=device) for L in xlen]) # (T,)
# tokens from the end (0 for last token, 1 for second-to-last, ...)
pos_from_end = torch.cat(
[torch.arange(L - 1, -1, -1, device=device) for L in xlen]
) # (T,)
T = x_groups.numel()
# Audio-side: group id per frame
y_groups = torch.arange(n, device=device).repeat_interleave(
torch.tensor(ylen, device=device)
) # (A,)
A = y_groups.numel()
# Broadcast to (A, T)
G_audio = y_groups[:, None] # (A, 1)
G_text = x_groups[None, :] # (1, T)
# Conditions:
# 1) aligned word: all tokens
aligned = G_text == G_audio
# 2) previous word: last `past_tokens` tokens only
if past_tokens > 0:
prev_group = G_text == (G_audio - 1)
prev_tail = pos_from_end[None, :] < past_tokens
prev_ok = prev_group & prev_tail
else:
prev_ok = torch.zeros((A, T), dtype=torch.bool, device=device)
# 3) next word: first `future_tokens` tokens only
if future_tokens > 0:
next_group = G_text == (G_audio + 1)
next_head = pos_in_group[None, :] < future_tokens
next_ok = next_group & next_head
else:
next_ok = torch.zeros((A, T), dtype=torch.bool, device=device)
mask = (aligned | prev_ok | next_ok).to(dtype=dtype)
return mask
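
# Illustrative usage, reusing the docstring's group sizes: 11 audio frames
# over 6 text tokens, with one token of context from each neighboring word.
#   m = audio_to_text_partial_neighbor_mask(
#       [2, 1, 3], [4, 2, 5], past_tokens=1, future_tokens=1
#   )  # (11, 6) bool
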
def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor:
    """Block-diagonal mask between packed sequences of group sizes xlen and ylen.

    Row i (a token of group g on the x side) may attend the y positions in
    group g's span, optionally widened by `offset` on both sides.
    Returns a (sum(xlen), sum(ylen)) boolean mask.
    """
    ybound = [0] + list(accumulate(ylen, int.__add__))
    lb, hb = [], []
    for n, l, h in zip(xlen, ybound[:-1], ybound[1:]):
        lb += [l] * n
        hb += [h] * n
    lb, hb = map(torch.tensor, (lb, hb))
    if offset:
        lb -= offset
        hb += offset
    rge = torch.arange(ybound[-1])
    lm = rge.unsqueeze(0) >= lb.unsqueeze(1)
    hm = rge.unsqueeze(0) < hb.unsqueeze(1)
    return lm & hm
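
# Illustrative usage: x groups of [2, 1] against y groups of [3, 2]
# give a (3, 5) block-diagonal mask.
#   packmask_2d([2, 1], [3, 2])
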
def topk_sampling(seq, k=1, temp=1.0):
    """Sample one id per row from the temperature-scaled top-k distribution."""
    logits = seq / temp
    # Threshold on the scaled logits so temp != 1.0 keeps exactly k candidates.
    topk = torch.topk(logits, k, dim=-1)
    mask = logits < topk.values[:, [-1]]
    logits[mask] = -float("Inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
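
# Illustrative usage: one sample per row from the 5 most likely logits.
#   ids = topk_sampling(torch.randn(2, 100), k=5, temp=0.8)  # (2, 1)
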
def delay_rvq(
    code,
    head_token: int = -2,
    tail_token: int = -3,
):
    """Apply a delay pattern to (q, n) RVQ codes: level i is shifted right by i + 1.

    A triangular extension of head/tail filler tokens is appended so each row
    can roll without wrapping real codes; output is (q, n + q + 1).
    """
    q, _ = code.shape
    extension = torch.ones((q, q + 1)).tril() * head_token
    extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token
    extension = torch.flip(extension, (1,))
    extended_code = torch.cat((code, extension), dim=1)
    for i in range(q):
        extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)
    return extended_code.long()
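
# Illustrative usage: 3 quantizer levels over 4 steps.
#   delayed = delay_rvq(torch.zeros(3, 4))  # (3, 8): row i shifted right by i + 1
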
def undelay_rvq(extended_code):
    """Invert delay_rvq on (q, b, n) codes: roll each level back, drop the filler."""
    q, _, n = extended_code.shape
    out = []
    for i in range(q):
        out.append(torch.roll(extended_code[i], -(i + 1), dims=1))
    out = torch.stack(out, dim=0)
    return out[:, :, : -(q + 1)]
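
# Illustrative roundtrip (note the extra batch dim expected by undelay_rvq):
#   codes = torch.randint(0, 1024, (3, 1, 4))
#   delayed = delay_rvq(codes[:, 0]).unsqueeze(1)  # (3, 1, 8)
#   undelay_rvq(delayed)                           # back to (3, 1, 4)
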
def sequence_mask(lengths, max_len=None, **kwargs):
    """Boolean mask of shape (batch, max_len), True where position < lengths[b]."""
    batch_size = lengths.shape[0]
    device = lengths.device
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=device).unsqueeze(0).expand(batch_size, -1)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
    return mask
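
# Illustrative usage:
#   sequence_mask(torch.tensor([2, 4]))  # (2, 4) bool, True where pos < length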