File size: 7,594 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
from itertools import accumulate
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F

# Module-wide default compute device: the CUDA device when PyTorch reports one
# is available, otherwise the CPU.
default_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def widen_alignment(
    alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S"
) -> torch.Tensor:
    """
    Widen 1-bands along one axis of an alignment matrix.

    Args:
        alignment: (B, T, S) binary/bool/int tensor
        width: int or (left, right) expansion
               e.g. 2 -> expand ±2
               (1,3) -> expand -1 on the left, +3 on the right
        axis: "S" to widen horizontally (across S),
              "T" to widen vertically (across T)

    Returns:
        (B, T, S) tensor with widened 1-bands along the chosen axis
    """
    assert axis in ("S", "T")
    orig_dtype = alignment.dtype
    dev = alignment.device

    # normalize widths
    if isinstance(width, int):
        left, right = width, width
    else:
        left, right = width
    ksize = left + right + 1
    kernel = torch.ones(1, 1, ksize, device=dev)

    if axis == "S":
        # (B*T, 1, S)
        x = alignment.view(-1, 1, alignment.size(-1)).float()
        x = F.pad(x, (left, right))  # explicit asymmetric padding
        y = F.conv1d(x, kernel)
        y = (y > 0).view_as(alignment)

    else:  # axis == "T"
        # (B*S, 1, T)
        x = (
            alignment.permute(0, 2, 1)
            .contiguous()
            .view(-1, 1, alignment.size(1))
            .float()
        )
        x = F.pad(x, (left, right))
        y = F.conv1d(x, kernel)
        # Back to (B, T, S)
        y = (
            (y > 0)
            .view(alignment.size(0), alignment.size(2), alignment.size(1))
            .permute(0, 2, 1)
        )

    # Cast back to original dtype
    if orig_dtype == torch.bool:
        return y
    elif orig_dtype.is_floating_point:
        return y.to(orig_dtype)
    else:
        return y.to(orig_dtype)


def collect_heads(cache, selected_heads):
    """
    Gather the last-position cross-attention row of selected (layer, head) pairs.

    Args:
        cache: indexable by layer; each entry maps "crossatt_weights" to a
            tensor indexed as [batch, head, position, ...]. Presumably
            (B, H, T_q, S); TODO confirm against the producer.
        selected_heads: iterable of (layer, head) index pairs.

    Returns:
        Tensor stacked on dim 1, one slot per selected pair. Each slot is
        weights[:, [head], [-1]], which keeps a singleton dimension, so the
        result is (B, len(selected_heads), 1, ...).
    """
    picked = []
    for layer_idx, head_idx in selected_heads:
        weights = cache[layer_idx]["crossatt_weights"]
        # list indices on dims 1 and 2 keep a singleton axis in the result
        picked.append(weights[:, [head_idx], [-1]])
    return torch.stack(picked, dim=1)


def expand(x, r):
    """
    Repeat every element along the last axis ``r`` times, consecutively.

    ``[a, b] -> [a, a, ..., b, b, ...]`` on the last dimension.

    Args:
        x: 3-D tensor of shape (b, n, d).
        r: repeat factor.

    Returns:
        Tensor of shape (b, n, r * d).
    """
    _, _, _ = x.shape  # unpacking rejects anything that is not 3-D
    return x.repeat_interleave(r, dim=-1)


def path_matrix(
    positions: torch.Tensor, num_positions: Optional[int] = None
) -> torch.Tensor:
    """
    One-hot encode ``positions`` into an int32 path matrix.

    Args:
        positions: integer (int64) tensor of class indices.
        num_positions: number of classes; defaults to ``positions.max() + 1``.

    Returns:
        Tensor of shape ``positions.shape + (num_positions,)`` with dtype int32.
    """
    n_classes = (
        num_positions if num_positions is not None else positions.max().item() + 1
    )
    encoded = F.one_hot(positions, num_classes=n_classes)
    return encoded.to(torch.int)


def pad_2d_sequence(seq, padding_value=0):
    """
    Pad a sequence of 2-D tensors to a common shape and stack them.

    Each tensor is padded at the bottom (rows) and right (columns) up to the
    largest row and column count in ``seq``.

    Args:
        seq: sequence of 2-D tensors (iterated twice, so not a one-shot iterator).
        padding_value: fill value for the padded cells.

    Returns:
        Tensor of shape (len(seq), max_rows, max_cols).
    """
    rows, cols = zip(*(t.shape for t in seq))
    target_rows, target_cols = max(rows), max(cols)

    padded = []
    for t in seq:
        # F.pad order: (left, right, top, bottom) for the last two dims
        extra = (0, target_cols - t.shape[1], 0, target_rows - t.shape[0])
        padded.append(F.pad(t, extra, value=padding_value))
    return torch.stack(padded)


def audio_to_text_partial_neighbor_mask(
    xlen,
    ylen,
    *,
    past_tokens: int = 0,
    future_tokens: int = 0,
    device=None,
    dtype=torch.bool,
):
    """
    Build an (audio_len, text_len) boolean mask where True = allowed to attend.
    Each audio frame (group g) can attend:
      - all tokens of text group g (aligned word),
      - last `past_tokens` tokens of text group g-1 (previous word),
      - first `future_tokens` tokens of text group g+1 (next word).

    Args:
        xlen (list[int]): token counts per text word (groups), e.g. [2,1,3]
        ylen (list[int]): frame counts per audio word (aligned groups), e.g. [4,2,5]
        past_tokens (int): allow up to this many tokens from end of previous word
        future_tokens (int): allow up to this many tokens from start of next word
        device: torch device
        dtype: output dtype (bool by default)

    Returns:
        mask: (A, T) tensor of ``dtype`` (A = sum(ylen), T = sum(xlen)).
            Empty ``xlen``/``ylen`` yield a (0, 0) mask.

    Raises:
        ValueError: on mismatched group counts, non-positive lengths, or
            negative ``past_tokens``/``future_tokens``.
    """
    if len(xlen) != len(ylen):
        raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}")
    if any(n_tok <= 0 for n_tok in xlen) or any(n_frm <= 0 for n_frm in ylen):
        raise ValueError("All lengths must be positive.")
    if past_tokens < 0 or future_tokens < 0:
        raise ValueError("past_tokens and future_tokens must be >= 0.")

    n = len(xlen)
    # Explicit long dtype: torch.tensor([]) defaults to float32, which
    # repeat_interleave rejects, so empty inputs would otherwise crash.
    x_counts = torch.tensor(xlen, dtype=torch.long, device=device)
    y_counts = torch.tensor(ylen, dtype=torch.long, device=device)
    group_ids = torch.arange(n, device=device)

    # Group id per text token (T,) and per audio frame (A,)
    x_groups = group_ids.repeat_interleave(x_counts)
    y_groups = group_ids.repeat_interleave(y_counts)

    # Position of each token inside its group, computed without a Python loop:
    # global token index minus the index of its group's first token.
    x_starts = torch.cumsum(x_counts, dim=0) - x_counts  # (n,)
    pos_in_group = torch.arange(x_groups.numel(), device=device) - x_starts[x_groups]
    # tokens from the end (0 for last token, 1 for second-to-last, ...)
    pos_from_end = x_counts[x_groups] - 1 - pos_in_group

    # Broadcast to (A, T)
    G_audio = y_groups[:, None]  # (A, 1)
    G_text = x_groups[None, :]  # (1, T)

    # 1) aligned word: all tokens
    mask = G_text == G_audio

    # 2) previous word: last `past_tokens` tokens only
    if past_tokens > 0:
        mask |= (G_text == G_audio - 1) & (pos_from_end[None, :] < past_tokens)

    # 3) next word: first `future_tokens` tokens only
    if future_tokens > 0:
        mask |= (G_text == G_audio + 1) & (pos_in_group[None, :] < future_tokens)

    return mask.to(dtype=dtype)


def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor:
    """
    Build a block-diagonal (sum(xlen), sum(ylen)) boolean mask.

    Row tokens of group ``g`` (``xlen[g]`` of them) are True on the column
    span of group ``g`` (``ylen[g]`` columns), widened by ``offset`` columns
    on each side. Groups are paired with ``zip``, so extra groups in the
    longer list are ignored.

    Args:
        xlen: row count per group.
        ylen: column count per group.
        offset: extra columns allowed on both sides of each group's span.

    Returns:
        Bool tensor of shape (sum(xlen[:k]), sum(ylen)), k = min(len(xlen), len(ylen)).
    """
    # int() normalizes numpy ints / 0-d tensors; the previous int.__add__
    # accumulator raised TypeError for anything that was not a Python int.
    row_counts = [int(v) for v in xlen]
    # Column boundaries: group g spans [ybound[g], ybound[g + 1]).
    ybound = [0, *accumulate(int(v) for v in ylen)]

    lb, hb = [], []
    for n_rows, lo, hi in zip(row_counts, ybound[:-1], ybound[1:]):
        lb.extend([lo] * n_rows)
        hb.extend([hi] * n_rows)

    # Explicit long dtype keeps empty inputs integer-typed as well.
    lb = torch.tensor(lb, dtype=torch.long) - offset
    hb = torch.tensor(hb, dtype=torch.long) + offset

    cols = torch.arange(ybound[-1])
    return (cols[None, :] >= lb[:, None]) & (cols[None, :] < hb[:, None])


def topk_sampling(seq, k=1, temp=1.0):
    """
    Sample one index per row from the ``k`` highest-scoring entries.

    Scores are divided by ``temp`` before the softmax. Entries tied with the
    k-th largest value are kept as well.

    Args:
        seq: score tensor; the last dimension is the candidate axis (1-D or
            batched).
        k: number of top candidates kept per row.
        temp: softmax temperature; must be > 0.

    Returns:
        Long tensor of sampled indices with a trailing size-1 dimension
        (``torch.multinomial(..., num_samples=1)``).

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError(f"temp must be > 0, got {temp}")
    logits = seq / temp
    # Threshold on the *scaled* logits. Comparing scaled logits against an
    # unscaled threshold could mask every entry (NaN probabilities) when temp != 1.
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def delay_rvq(
    code,
    head_token: int = -2,
    tail_token: int = -3,
):
    """
    Apply a per-codebook delay pattern to a (q, n) code matrix.

    Row ``i`` becomes ``[head_token] * (i + 1) + code[i] + [tail_token] * (q - i)``,
    so every row has length ``n + q + 1``.

    Args:
        code: 2-D tensor of shape (q, n), presumably (codebooks, time); TODO confirm.
        head_token: filler placed before each row's codes.
        tail_token: filler placed after each row's codes.

    Returns:
        Long tensor of shape (q, n + q + 1) on ``code.device``.
    """
    q, n = code.shape
    # Build directly in int64 on code's device. The previous float32 CPU
    # extension broke torch.cat for CUDA inputs and rounded large code values.
    out = torch.full(
        (q, n + q + 1), tail_token, dtype=torch.long, device=code.device
    )
    for i in range(q):
        out[i, : i + 1] = head_token
        out[i, i + 1 : i + 1 + n] = code[i]
    return out


def undelay_rvq(extended_code):
    """
    Undo a per-codebook delay on a 3-D tensor.

    Codebook ``i`` is rolled left by ``i + 1`` along its last axis, then the
    final ``q + 1`` columns are dropped.

    Args:
        extended_code: 3-D tensor (q, _, n), presumably (codebooks, batch, time);
            TODO confirm.

    Returns:
        Tensor of shape (q, _, n - q - 1).
    """
    num_q, _, _ = extended_code.shape
    realigned = torch.stack(
        [
            torch.roll(level, -(idx + 1), dims=1)
            for idx, level in enumerate(extended_code)
        ],
        dim=0,
    )
    return realigned[..., : -(num_q + 1)]


def sequence_mask(lengths, max_len=None, **kwargs):
    """
    Build a (batch, max_len) boolean padding mask from sequence lengths.

    ``mask[b, t]`` is True when ``t < lengths[b]``.

    Args:
        lengths: 1-D tensor of per-sequence lengths.
        max_len: mask width; defaults to ``lengths.max()`` (0 for an empty batch).
        **kwargs: accepted and ignored (kept for call-site compatibility).

    Returns:
        Bool tensor of shape (len(lengths), max_len) on ``lengths.device``.
    """
    if max_len is None:
        # torch.max raises on an empty tensor; an empty batch gets width 0.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    # Allocate on the target device (no host->device copy per call) and let
    # broadcasting form the (batch, max_len) comparison.
    ids = torch.arange(0, max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)