Spaces:
Running
on
Zero
Running
on
Zero
| from itertools import accumulate | |
| from typing import Callable, List, Optional | |
| import torch | |
| import torch.nn.functional as F | |
# Module-wide default device: CUDA when PyTorch reports it available, CPU otherwise.
default_device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
| def widen_alignment( | |
| alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S" | |
| ) -> torch.Tensor: | |
| """ | |
| Widen 1-bands along one axis of an alignment matrix. | |
| Args: | |
| alignment: (B, T, S) binary/bool/int tensor | |
| width: int or (left, right) expansion | |
| e.g. 2 -> expand ±2 | |
| (1,3) -> expand -1 on the left, +3 on the right | |
| axis: "S" to widen horizontally (across S), | |
| "T" to widen vertically (across T) | |
| Returns: | |
| (B, T, S) tensor with widened 1-bands along the chosen axis | |
| """ | |
| assert axis in ("S", "T") | |
| orig_dtype = alignment.dtype | |
| dev = alignment.device | |
| # normalize widths | |
| if isinstance(width, int): | |
| left, right = width, width | |
| else: | |
| left, right = width | |
| ksize = left + right + 1 | |
| kernel = torch.ones(1, 1, ksize, device=dev) | |
| if axis == "S": | |
| # (B*T, 1, S) | |
| x = alignment.view(-1, 1, alignment.size(-1)).float() | |
| x = F.pad(x, (left, right)) # explicit asymmetric padding | |
| y = F.conv1d(x, kernel) | |
| y = (y > 0).view_as(alignment) | |
| else: # axis == "T" | |
| # (B*S, 1, T) | |
| x = ( | |
| alignment.permute(0, 2, 1) | |
| .contiguous() | |
| .view(-1, 1, alignment.size(1)) | |
| .float() | |
| ) | |
| x = F.pad(x, (left, right)) | |
| y = F.conv1d(x, kernel) | |
| # Back to (B, T, S) | |
| y = ( | |
| (y > 0) | |
| .view(alignment.size(0), alignment.size(2), alignment.size(1)) | |
| .permute(0, 2, 1) | |
| ) | |
| # Cast back to original dtype | |
| if orig_dtype == torch.bool: | |
| return y | |
| elif orig_dtype.is_floating_point: | |
| return y.to(orig_dtype) | |
| else: | |
| return y.to(orig_dtype) | |
def collect_heads(cache, selected_heads):
    """
    Gather cross-attention weights of chosen heads into a single tensor.

    Args:
        cache: indexable by layer; each entry is a mapping that holds a
            tensor under the key "crossatt_weights".
        selected_heads: iterable of (layer, head) pairs.

    Returns:
        The per-pair slices stacked along a new dimension 1.
    """
    picked = []
    for layer, head in selected_heads:
        weights = cache[layer]["crossatt_weights"]
        # Keep `head` on dim 1 and the last index on dim 2; list indices
        # retain a singleton dimension for the selected pair.
        # NOTE(review): presumably (batch, head, query, key) — confirm.
        picked.append(weights[:, [head], [-1]])
    return torch.stack(picked, dim=1)
def expand(x, r):
    """
    Repeat every element of the last dimension `r` times, back to back.

    Args:
        x: 3-D tensor of shape (b, n, d).
        r: repetition count.

    Returns:
        Tensor of shape (b, n, r * d) where each original value appears
        `r` consecutive times.
    """
    batch, length, dim = x.shape
    # Add a trailing axis, broadcast it to size r, then fold it into `dim`.
    tiled = x.unsqueeze(-1).expand(batch, length, dim, r)
    return tiled.reshape(batch, length, r * dim)
def path_matrix(positions: torch.Tensor, num_positions: int = None) -> torch.Tensor:
    """
    One-hot encode integer positions.

    Args:
        positions: integer tensor of indices.
        num_positions: size of the one-hot axis; when None it is taken as
            the largest value in `positions` plus one.

    Returns:
        Tensor of dtype torch.int with a trailing axis of size
        `num_positions`.
    """
    width = positions.max().item() + 1 if num_positions is None else num_positions
    encoded = F.one_hot(positions, num_classes=width)
    return encoded.to(torch.int)
def pad_2d_sequence(seq, padding_value=0):
    """
    Pad a collection of 2-D tensors to a common shape and stack them.

    Args:
        seq: sequence of 2-D tensors.
        padding_value: fill value for the padded area.

    Returns:
        Tensor of shape (len(seq), max_rows, max_cols). Each item is padded
        at the bottom and on the right.
    """
    # Transpose the list of shapes so each dimension can be maximised.
    per_dim_sizes = zip(*(item.shape for item in seq))
    max_rows, max_cols = (max(sizes) for sizes in per_dim_sizes)

    padded = []
    for item in seq:
        # F.pad takes the last dimension first: (left, right, top, bottom).
        amounts = (0, max_cols - item.shape[1], 0, max_rows - item.shape[0])
        padded.append(torch.nn.functional.pad(item, amounts, value=padding_value))
    return torch.stack(padded)
def audio_to_text_partial_neighbor_mask(
    xlen,
    ylen,
    *,
    past_tokens: int = 0,
    future_tokens: int = 0,
    device=None,
    dtype=torch.bool,
):
    """
    Build an (audio_len, text_len) mask where True marks an allowed
    audio-frame -> text-token attention pair.

    Word g on the audio side may look at:
      - every token of text word g,
      - the trailing `past_tokens` tokens of text word g-1,
      - the leading `future_tokens` tokens of text word g+1.

    Args:
        xlen (list[int]): tokens per text word, e.g. [2,1,3]
        ylen (list[int]): frames per audio word, e.g. [4,2,5]
        past_tokens (int): how many tokens of the previous word are visible
        future_tokens (int): how many tokens of the next word are visible
        device: torch device for every tensor that is created
        dtype: dtype of the returned mask (bool by default)

    Returns:
        mask: (A, T) tensor with A = sum(ylen) and T = sum(xlen)

    Raises:
        ValueError: on mismatched list lengths, non-positive lengths, or
            negative `past_tokens` / `future_tokens`.
    """
    if len(xlen) != len(ylen):
        raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}")
    if any(size <= 0 for size in (*xlen, *ylen)):
        raise ValueError("All lengths must be positive.")
    if min(past_tokens, future_tokens) < 0:
        raise ValueError("past_tokens and future_tokens must be >= 0.")

    word_ids = torch.arange(len(xlen), device=device)

    # Word index of every text token (T,) and of every audio frame (A,).
    token_word = word_ids.repeat_interleave(torch.tensor(xlen, device=device))
    frame_word = word_ids.repeat_interleave(torch.tensor(ylen, device=device))

    # Token offset inside its word, counted from the start and from the end
    # (0 is the first / last token respectively).
    from_start = torch.cat([torch.arange(size, device=device) for size in xlen])
    from_end = torch.cat(
        [torch.arange(size - 1, -1, -1, device=device) for size in xlen]
    )

    frame_col = frame_word.unsqueeze(1)  # (A, 1)
    token_row = token_word.unsqueeze(0)  # (1, T)

    # Start from the aligned word, then OR in the neighbour windows.
    allowed = token_row == frame_col
    if past_tokens > 0:
        tail_of_prev = (token_row == frame_col - 1) & (
            from_end.unsqueeze(0) < past_tokens
        )
        allowed = allowed | tail_of_prev
    if future_tokens > 0:
        head_of_next = (token_row == frame_col + 1) & (
            from_start.unsqueeze(0) < future_tokens
        )
        allowed = allowed | head_of_next

    return allowed.to(dtype=dtype)
def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor:
    """
    Build a block mask pairing packed x-segments with packed y-segments.

    Row r belongs to the i-th x-segment and is True on the columns of the
    i-th y-segment, optionally widened by `offset` columns on both sides.

    Args:
        xlen: length of each segment on the row axis.
        ylen: length of each segment on the column axis.
        offset: number of extra columns allowed on each side of a segment.

    Returns:
        Boolean tensor of shape (sum(xlen), sum(ylen)).
    """
    # Cumulative column boundaries: segment i spans [edges[i], edges[i + 1]).
    edges = [0, *accumulate(ylen)]

    starts, stops = [], []
    for rows, lo, hi in zip(xlen, edges, edges[1:]):
        starts.extend([lo] * rows)
        stops.extend([hi] * rows)
    starts = torch.tensor(starts)
    stops = torch.tensor(stops)

    if offset:
        starts.sub_(offset)
        stops.add_(offset)

    columns = torch.arange(edges[-1]).unsqueeze(0)
    at_or_after_start = columns >= starts.unsqueeze(1)
    before_stop = columns < stops.unsqueeze(1)
    return at_or_after_start & before_stop
def topk_sampling(seq, k=1, temp=1.0):
    """
    Sample one index per row from the top-k entries of temperature-scaled logits.

    Args:
        seq: logits whose last dimension holds the candidates.
            torch.multinomial accepts 1-D or 2-D probabilities, so `seq`
            must be 1-D or 2-D. It is not modified.
        k: number of highest-scoring candidates kept per row.
        temp: softmax temperature; must be positive.

    Returns:
        Long tensor of sampled indices with a trailing dimension of size 1.

    Raises:
        ValueError: if `temp` is not strictly positive.
    """
    if temp <= 0:
        raise ValueError("temp must be > 0.")

    logits = seq / temp
    # The cut-off must come from the *scaled* logits: comparing scaled values
    # with an unscaled threshold keeps the wrong number of candidates
    # whenever temp != 1 (and can mask every entry, producing NaNs).
    kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth_best, -float("Inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
def delay_rvq(
    code,
    head_token: int = -2,
    tail_token: int = -3,
):
    """
    Apply the staggered "delay" layout to a (q, n) code matrix.

    Row i is shifted right by i + 1 steps: it is prefixed with i + 1 copies
    of `head_token` and suffixed with q - i copies of `tail_token`, so every
    row has length n + q + 1.

    Args:
        code: 2-D tensor of shape (q, n).
        head_token: filler placed before the codes of each row.
        tail_token: filler placed after the codes of each row.

    Returns:
        Long tensor of shape (q, n + q + 1) on the same device as `code`.
    """
    q, _ = code.shape
    dev = code.device

    # Build the filler block in int64 on the input's device: a float filler
    # would promote integer codes to float32 in torch.cat (corrupting values
    # above 2**24), and a CPU filler cannot be concatenated with GPU codes.
    heads = torch.ones((q, q + 1), dtype=torch.long, device=dev).tril() * head_token
    tails = (
        torch.ones((q + 1, q), dtype=torch.long, device=dev).tril(diagonal=-1).T
        * tail_token
    )
    # After the flip, row i reads [tail] * (q - i) + [head] * (i + 1).
    extension = torch.flip(heads + tails, (1,))

    extended_code = torch.cat((code.long(), extension), dim=1)
    for i in range(q):
        # Rotating by i + 1 wraps the head tokens round to the front.
        extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)
    return extended_code.long()
def undelay_rvq(extended_code):
    """
    Undo the per-row shift of a delayed code tensor.

    Args:
        extended_code: 3-D tensor whose first dimension indexes the q rows
            and whose last dimension is the delayed sequence.

    Returns:
        Tensor with row i rotated left by i + 1 along the last dimension
        and the final q + 1 positions removed.
    """
    num_levels, _, _ = extended_code.shape
    realigned = torch.stack(
        [
            torch.roll(extended_code[level], -(level + 1), dims=1)
            for level in range(num_levels)
        ],
        dim=0,
    )
    # Drop the q + 1 filler slots that the rotation moved to the end.
    return realigned[:, :, : -(num_levels + 1)]
def sequence_mask(lengths, max_len=None, **kwargs):
    """
    Turn a vector of lengths into a boolean padding mask.

    Args:
        lengths: 1-D tensor of sequence lengths, shape (batch,).
        max_len: width of the mask; defaults to the largest length.
        **kwargs: accepted and ignored.

    Returns:
        Bool tensor of shape (batch, max_len), on the device of `lengths`,
        where entry [b, t] is True iff t < lengths[b].
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    num_rows = lengths.shape[0]
    steps = torch.arange(0, max_len, device=lengths.device)
    step_grid = steps.unsqueeze(0).expand(num_rows, -1)
    limit_grid = lengths.unsqueeze(1).expand(-1, max_len)
    return step_grid < limit_grid