from itertools import accumulate
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F

default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def widen_alignment(
    alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S"
) -> torch.Tensor:
    """
    Widen 1-bands along one axis of an alignment matrix.

    Args:
        alignment: (B, T, S) binary/bool/int tensor
        width: int or (left, right) expansion
            e.g. 2 -> expand ±2
                 (1, 3) -> expand -1 on the left, +3 on the right
        axis: "S" to widen horizontally (across S),
              "T" to widen vertically (across T)

    Returns:
        (B, T, S) tensor with widened 1-bands along the chosen axis
    """
    assert axis in ("S", "T")
    orig_dtype = alignment.dtype
    dev = alignment.device

    # Normalize widths
    if isinstance(width, int):
        left, right = width, width
    else:
        left, right = width

    ksize = left + right + 1
    kernel = torch.ones(1, 1, ksize, device=dev)

    if axis == "S":
        # (B*T, 1, S)
        x = alignment.view(-1, 1, alignment.size(-1)).float()
        x = F.pad(x, (left, right))  # explicit asymmetric padding
        y = F.conv1d(x, kernel)
        y = (y > 0).view_as(alignment)
    else:  # axis == "T"
        # (B*S, 1, T)
        x = (
            alignment.permute(0, 2, 1)
            .contiguous()
            .view(-1, 1, alignment.size(1))
            .float()
        )
        x = F.pad(x, (left, right))
        y = F.conv1d(x, kernel)
        # Back to (B, T, S)
        y = (
            (y > 0)
            .view(alignment.size(0), alignment.size(2), alignment.size(1))
            .permute(0, 2, 1)
        )

    # Cast back to the original dtype (bool results can be returned as-is)
    if orig_dtype == torch.bool:
        return y
    return y.to(orig_dtype)


def collect_heads(cache, selected_heads):
    # Stack, for each selected (layer, head) pair, that head's cross-attention
    # weights at the last query position, along a new dim=1.
    return torch.stack(
        [
            cache[layer]["crossatt_weights"][:, [head], [-1]]
            for layer, head in selected_heads
        ],
        dim=1,
    )


def expand(x, r):
    # Repeat each feature r times along the last dim: (B, N, D) -> (B, N, r*D).
    b, n, d = x.shape
    x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
    return x


def path_matrix(
    positions: torch.Tensor, num_positions: Optional[int] = None
) -> torch.Tensor:
    # One-hot encode positions into an int path matrix.
    if num_positions is None:
        num_positions = positions.max().item() + 1
    return F.one_hot(positions, num_classes=num_positions).to(torch.int)


def pad_2d_sequence(seq, padding_value=0):
    # Right/bottom-pad a list of 2D tensors to a common shape, then stack them.
    max_x, max_y = map(max, zip(*map(lambda x: x.shape, seq)))

    def pad(x):
        return F.pad(
            x, (0, max_y - x.shape[1], 0, max_x - x.shape[0]), value=padding_value
        )

    return torch.stack([pad(x) for x in seq])
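

# Illustrative usage sketch (not part of the original module): widening a toy
# alignment by one step on each side along S. Tensor shapes and values here
# are assumptions chosen only to show the expected call pattern.
def _demo_widen_alignment():
    a = torch.zeros(1, 3, 5, dtype=torch.bool)
    a[0, 1, 2] = True
    widened = widen_alignment(a, 1, axis="S")
    # Columns 1..3 of row 1 are now True; shape (1, 3, 5) and dtype are kept.
    return widened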


def audio_to_text_partial_neighbor_mask(
    xlen,
    ylen,
    *,
    past_tokens: int = 0,
    future_tokens: int = 0,
    device=None,
    dtype=torch.bool,
):
    """
    Build an (audio_len, text_len) boolean mask where True = allowed to attend.

    Each audio frame (group g) can attend:
      - all tokens of text group g (aligned word),
      - last `past_tokens` tokens of text group g-1 (previous word),
      - first `future_tokens` tokens of text group g+1 (next word).

    Args:
        xlen (list[int]): token counts per text word (groups), e.g. [2, 1, 3]
        ylen (list[int]): frame counts per audio word (aligned groups), e.g. [4, 2, 5]
        past_tokens (int): allow up to this many tokens from end of previous word
        future_tokens (int): allow up to this many tokens from start of next word
        device: torch device
        dtype: output dtype (bool by default)

    Returns:
        mask: (A, T) boolean tensor (A = sum(ylen), T = sum(xlen))
    """
    if len(xlen) != len(ylen):
        raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}")
    if any(l <= 0 for l in xlen) or any(l <= 0 for l in ylen):
        raise ValueError("All lengths must be positive.")
    if past_tokens < 0 or future_tokens < 0:
        raise ValueError("past_tokens and future_tokens must be >= 0.")

    n = len(xlen)

    # Text side: group id per token and position within its group
    x_groups = torch.arange(n, device=device).repeat_interleave(
        torch.tensor(xlen, device=device)
    )  # (T,)
    pos_in_group = torch.cat([torch.arange(L, device=device) for L in xlen])  # (T,)
    # Tokens from the end (0 for last token, 1 for second-to-last, ...)
    pos_from_end = torch.cat(
        [torch.arange(L - 1, -1, -1, device=device) for L in xlen]
    )  # (T,)
    T = x_groups.numel()

    # Audio side: group id per frame
    y_groups = torch.arange(n, device=device).repeat_interleave(
        torch.tensor(ylen, device=device)
    )  # (A,)
    A = y_groups.numel()

    # Broadcast to (A, T)
    G_audio = y_groups[:, None]  # (A, 1)
    G_text = x_groups[None, :]  # (1, T)

    # Conditions:
    # 1) aligned word: all tokens
    aligned = G_text == G_audio

    # 2) previous word: last `past_tokens` tokens only
    if past_tokens > 0:
        prev_group = G_text == (G_audio - 1)
        prev_tail = pos_from_end[None, :] < past_tokens
        prev_ok = prev_group & prev_tail
    else:
        prev_ok = torch.zeros((A, T), dtype=torch.bool, device=device)

    # 3) next word: first `future_tokens` tokens only
    if future_tokens > 0:
        next_group = G_text == (G_audio + 1)
        next_head = pos_in_group[None, :] < future_tokens
        next_ok = next_group & next_head
    else:
        next_ok = torch.zeros((A, T), dtype=torch.bool, device=device)

    mask = (aligned | prev_ok | next_ok).to(dtype=dtype)
    return mask


def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor:
    """
    Block mask of shape (sum(xlen), sum(ylen)): position i in x-group g may
    attend to the y-span of group g, optionally widened by `offset` on both sides.
    """
    ybound = [0] + list(accumulate(ylen))
    lb, hb = [], []
    for n, l, h in zip(xlen, ybound[:-1], ybound[1:]):
        lb += [l] * n
        hb += [h] * n
    lb, hb = map(torch.tensor, (lb, hb))
    if offset:
        lb -= offset
        hb += offset
    rge = torch.arange(ybound[-1])
    lm = rge.unsqueeze(0) >= lb.unsqueeze(1)
    hm = rge.unsqueeze(0) < hb.unsqueeze(1)
    return lm & hm


def topk_sampling(seq, k=1, temp=1.0):
    """Sample one index per row from the top-k of temperature-scaled logits."""
    logits = seq / temp
    # Take the top-k on the tempered logits so the threshold below is on the
    # same scale as the values it is compared against.
    topk = torch.topk(logits, k, dim=-1)
    mask = logits < topk.values[:, [-1]]
    logits[mask] = -float("Inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def delay_rvq(
    code,
    head_token: int = -2,
    tail_token: int = -3,
):
    """
    Apply a delay pattern to an RVQ code matrix (q, n): level i is shifted
    right by i + 1 steps, padded with `head_token` before the code and
    `tail_token` after it. Output shape: (q, n + q + 1).
    """
    q, _ = code.shape
    # Build the head/tail padding on the same device as the code.
    extension = torch.ones((q, q + 1), device=code.device).tril() * head_token
    extension += (
        torch.ones((q + 1, q), device=code.device).tril(diagonal=-1).T * tail_token
    )
    extension = torch.flip(extension, (1,))
    extended_code = torch.cat((code, extension), dim=1)
    for i in range(q):
        extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)

    return extended_code.long()


def undelay_rvq(extended_code):
    """Invert `delay_rvq` for a batched code of shape (q, B, n + q + 1)."""
    q, _, n = extended_code.shape
    out = []
    for i in range(q):
        out.append(torch.roll(extended_code[i], -(i + 1), dims=1))
    out = torch.stack(out, dim=0)
    return out[:, :, : -(q + 1)]
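

# Illustrative round-trip sketch (not part of the original module): delay_rvq
# delays codebook level i by i + 1 steps; undelay_rvq, which expects an extra
# middle (batch) dimension, undoes the shift and strips the padding. The toy
# values below are assumptions for demonstration only.
def _demo_delay_roundtrip():
    code = torch.arange(12).reshape(3, 4)  # q = 3 codebooks, n = 4 frames
    delayed = delay_rvq(code)  # (3, 4 + 3 + 1)
    restored = undelay_rvq(delayed.unsqueeze(1)).squeeze(1)
    assert torch.equal(restored, code)
    return delayed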


def sequence_mask(lengths, max_len=None, **kwargs):
    """Boolean mask of shape (batch, max_len), True where index < length."""
    batch_size = lengths.shape[0]
    device = lengths.device
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
    return mask
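

# Illustrative usage sketch (not part of the original module), reusing the
# example lengths from the audio_to_text_partial_neighbor_mask docstring:
# three words with 2/1/3 text tokens and 4/2/5 audio frames.
def _demo_neighbor_mask():
    mask = audio_to_text_partial_neighbor_mask(
        [2, 1, 3], [4, 2, 5], past_tokens=1, future_tokens=1
    )
    # Shape (4 + 2 + 5, 2 + 1 + 3) = (11, 6); frames of word 1 may also attend
    # to the last token of word 0 and the first token of word 2.
    return mask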