Lakoc committed
Commit 64c2cbc · verified · 1 Parent(s): 8f90ab1

Upload DiCoWForConditionalGeneration

Files changed (15)
  1. FDDT.py +75 -0
  2. README.md +199 -0
  3. SCBs.py +411 -0
  4. coattention.py +120 -0
  5. config.json +86 -0
  6. config.py +103 -0
  7. contrastive_loss.py +190 -0
  8. decoding.py +397 -0
  9. encoder.py +328 -0
  10. generation.py +1808 -0
  11. generation_config.json +12 -0
  12. layers.py +99 -0
  13. model.safetensors +3 -0
  14. modeling_dicow.py +450 -0
  15. utils.py +96 -0
FDDT.py ADDED
@@ -0,0 +1,75 @@
from typing import Optional

import torch
from torch import nn

from .layers import CustomDiagonalLinear, CustomLinear
from .SCBs import SpeakerCommunicationBlock


class FDDT(nn.Module):
    def __init__(self, config, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False,
                 use_silence=True, use_target=True, use_overlap=True, use_non_target=True,
                 use_interaction=False):
        super().__init__()
        if use_target:
            self.target_linear = (
                nn.Parameter(torch.zeros(d_model)) if bias_only
                else CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0)
            )
        if use_non_target:
            self.non_target_linear = (
                nn.Parameter(torch.zeros(d_model)) if bias_only
                else CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate)
            )
        if use_overlap:
            self.overlap_linear = (
                nn.Parameter(torch.zeros(d_model)) if bias_only
                else CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0)
            )
        if use_silence:
            self.silence_linear = (
                nn.Parameter(torch.zeros(d_model)) if bias_only
                else CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate)
            )

        if use_interaction:
            self.scb = SpeakerCommunicationBlock(config)

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        mask = torch.round(mask).bool()
        masked_hidden_states = hidden_states * mask
        return masked_hidden_states

    def forward(self, hidden_states, stno_mask):
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            if self.use_silence:
                hidden_states += stno_mask[:, 0, ...] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1, ...] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
        else:
            orig_hidden_states = hidden_states
            hidden_states = (
                (self.silence_linear(orig_hidden_states) if self.use_silence else orig_hidden_states) * stno_mask[:, 0, :]
                + (self.target_linear(orig_hidden_states) if self.use_target else orig_hidden_states) * stno_mask[:, 1, :]
                + (self.non_target_linear(orig_hidden_states) if self.use_non_target else orig_hidden_states) * stno_mask[:, 2, :]
                + (self.overlap_linear(orig_hidden_states) if self.use_overlap else orig_hidden_states) * stno_mask[:, 3, :]
            )
        if self.use_interaction:
            hidden_states = self.scb(hidden_states)
        return hidden_states
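
For orientation, a minimal self-contained sketch of what the bias-only FDDT path computes (an illustrative re-implementation under stated assumptions, not an import of the class above): each STNO class (silence, target, non-target, overlap) contributes a learned additive offset, weighted by its per-frame posterior.

import torch
from torch import nn

class BiasOnlyFDDT(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # one learned bias vector per STNO class: silence, target, non-target, overlap
        self.biases = nn.Parameter(torch.zeros(4, d_model))

    def forward(self, hidden_states, stno_mask):
        # hidden_states: (B, T, D); stno_mask: (B, 4, T) soft class posteriors
        return hidden_states + torch.einsum("bct,cd->btd", stno_mask, self.biases)

x = torch.randn(2, 10, 16)
stno = torch.softmax(torch.randn(2, 4, 10), dim=1)
print(BiasOnlyFDDT(16)(x, stno).shape)  # torch.Size([2, 10, 16])
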
README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

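Until the official snippet is filled in, the sketch below shows one plausible way to load the checkpoint, based on the `auto_map` entries in this repository's `config.json`; the model id is a placeholder, not a confirmed identifier.

```python
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

model_id = "<this-repo-id-or-local-path>"  # placeholder
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)               # resolves config.DiCoWConfig
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, trust_remote_code=True)  # resolves DiCoWForConditionalGeneration
```
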
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
SCBs.py ADDED
@@ -0,0 +1,411 @@
import torch
from torch import nn
from transformers import WhisperConfig
from transformers.activations import ACT2FN
from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
import torch.nn.functional as F

from .coattention import CoAttention
from .layers import CustomLinear, CustomDiagonalLinear, Gate, CustomLinearInitialized


class LowRankApproxSelectFirst(nn.Module):
    def __init__(self, d_in, d_out, rank):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.rank = rank
        self.proj_in = nn.Linear(d_in, rank)
        self.proj_out = nn.Linear(rank, d_out)

    def forward(self, x):
        return self.proj_out(self.proj_in(x))

    def _init_weights(self):
        # Create low-rank approximation of the identity projection from first d_out of input
        eye = torch.eye(self.d_out, self.d_in)  # (d_out x d_in)

        # Low-rank SVD of eye matrix
        U, S, Vh = torch.linalg.svd(eye, full_matrices=False)  # U: (d_out x d_out), Vh: (d_in x d_in)

        U_k = U[:, :self.rank]      # (d_out x rank)
        S_k = S[:self.rank]         # (rank,)
        V_k = Vh[:self.rank, :]     # (rank x d_in)

        A = V_k                     # (rank x d_in)
        B = U_k @ torch.diag(S_k)   # (d_out x rank)

        # Set weights
        self.proj_in.weight.data.copy_(A)
        self.proj_in.bias.data.zero_()
        self.proj_out.weight.data.copy_(B)
        self.proj_out.bias.data.zero_()


class TACBlock(nn.Module):
    def __init__(self, config: WhisperConfig, d_int_factor: float = 1, num_speakers=2):
        super().__init__()
        d = config.d_model
        d_prime = int(d * d_int_factor)
        self.num_speakers = num_speakers
        self.proj_in_1 = nn.Linear(d, d_prime, bias=True)
        self.proj_in_2 = nn.Linear(d, d_prime, bias=True)
        self.proj_int = nn.Linear(d_prime, d_prime, bias=True)
        self.proj_out_1 = nn.Linear(d + d_prime, d, bias=True)
        self.proj_out_2 = nn.Linear(d + d_prime, d, bias=True)
        self.activation_fn = ACT2FN[config.activation_function]
        self.norms = nn.ModuleList([nn.LayerNorm(d) for _ in range(self.num_speakers)])
        self.gate = Gate(self.num_speakers, 0.05)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # hidden_states: (B, self.num_speakers, T, F)
        x_proj = torch.stack(
            [self.activation_fn(self.proj_in_1(hidden_states[:, 0])),
             self.activation_fn(self.proj_in_2(hidden_states[:, 1]))],
            dim=1)  # (B, 2, T, d')
        x_mean = x_proj.mean(dim=1, keepdim=True)  # (B, 1, T, d')
        z = self.activation_fn(self.proj_int(x_mean))  # (B, 1, T, d')

        z_expand = z.expand(-1, self.num_speakers, -1, -1)  # (B, self.num_speakers, T, d')
        x_cat = torch.cat([hidden_states, z_expand], dim=-1)  # (B, self.num_speakers, T, d + d')
        x_out = torch.stack(
            [self.norms[0](self.proj_out_1(x_cat[:, 0])),
             self.norms[1](self.proj_out_2(x_cat[:, 1]))],
            dim=1)  # (B, self.num_speakers, T, d)
        return hidden_states + self.gate(x_out, dim=1)


class CrossAttentionBlock(nn.Module):
    def __init__(self, config: WhisperConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.num_speakers = getattr(config, "mt_num_speakers", 2)
        if self.num_speakers != 2:
            raise ValueError("CrossAttentionBlock supports only 2 speakers.")

        # Separate attention block per speaker
        self.attn_blocks = nn.ModuleList([
            WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
            )
            for _ in range(self.num_speakers)
        ])

        self.norms = nn.ModuleList([nn.LayerNorm(self.embed_dim) for _ in range(self.num_speakers)])
        self.gate = Gate(self.num_speakers, 0.01)

    def forward(self, hidden_states):
        # hidden_states: (B, 2, T, F)
        outputs = []
        for s in range(self.num_speakers):
            q = hidden_states[:, s]  # (B, T, F)
            other_s = 1 - s
            kv = hidden_states[:, other_s]  # (B, T, F)

            attn_out, _, _ = self.attn_blocks[s](hidden_states=q, key_value_states=kv)  # (B, T, F)
            outputs.append(self.norms[s](attn_out[:, None, :, :]))
        outputs = torch.concat(outputs, dim=1)
        outputs_modulated = self.gate(outputs, dim=1) + hidden_states
        return outputs_modulated


# class CrossAttentionEnrollBlock(nn.Module):
#     def __init__(self, config, layer_norm_eps: float = 1e-5):
#         super().__init__()
#         self.embed_dim = config.d_model
#         self.ffn_dim = config.encoder_ffn_dim
#
#         self.cross_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
#             embed_dim=self.embed_dim,
#             num_heads=config.encoder_attention_heads,
#             dropout=config.attention_dropout,
#             config=config,
#         )
#
#         # Layer normalization (pre-norm style)
#         self.norm_attn = nn.LayerNorm(self.embed_dim, eps=layer_norm_eps)
#         self.norm_ffn = nn.LayerNorm(self.embed_dim * 2, eps=layer_norm_eps)
#
#         # Feed-forward network
#         self.ffn = nn.Sequential(
#             nn.Linear(self.embed_dim * 2, self.ffn_dim),
#             ACT2FN[config.activation_function],
#             nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1),
#             nn.Linear(self.ffn_dim, self.embed_dim),
#             nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
#         )
#
#     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
#         """
#         Args:
#             hidden_states: (B, 2, T, F) - batch, channels, time, features
#         Returns:
#             Updated hidden states of same shape
#         """
#         q_channel = hidden_states[:, 0]  # (B, T, F)
#         kv_channel = hidden_states[:, 1]  # (B, T, F)
#
#         # Cross-attention with residual connection
#         q_normed = self.norm_attn(q_channel)
#         attn_output = self.cross_attn(
#             hidden_states=q_normed,
#             key_value_states=kv_channel,
#             output_attentions=False
#         )[0]
#
#         q_after_attn = torch.cat([attn_output, q_normed], dim=-1)
#
#         # Feed-forward with residual connection
#         q_normed_ffn = self.norm_ffn(q_after_attn)
#
#         ffn_output = self.ffn(q_normed_ffn)
#         updated_q = q_after_attn + ffn_output
#
#         # Return stacked result (only query channel is updated)
#         return torch.stack([updated_q, kv_channel], dim=1)


def first_init_fun(module):
    # Zero out all weights initially
    # module.weight.data.zero_()
    torch.nn.init.xavier_uniform_(module.weight, gain=0.1)

    # Create identity mapping for second half of input (q_normed part)
    # Input: [cross_attn_output, q_normed] -> map q_normed to first embed_dim outputs
    module.weight.data[:module.weight.shape[1] // 2, module.weight.shape[1] // 2:] += torch.eye(module.weight.shape[1] // 2)
    # module.weight.data[:module.weight.shape[1]//2, module.weight.shape[1]//2:] = torch.eye(module.weight.shape[1]//2)

    # Zero bias
    module.bias.data.zero_()


def second_init_fun(module):
    # module.weight.data.zero_()
    torch.nn.init.xavier_uniform_(module.weight, gain=0.1)

    # Create identity mapping from first embed_dim inputs to output
    module.weight.data[:, :module.weight.shape[0]] += torch.eye(module.weight.shape[0])

    # Zero bias for second linear
    module.bias.data.zero_()


# Cross attention block that can easily learn to ignore cross attention initially
class CrossAttentionEnrollBlockNew(nn.Module):
    def __init__(self, config, layer_norm_eps: float = 1e-5):
        super().__init__()
        self.embed_dim = config.d_model
        self.ffn_dim = config.encoder_ffn_dim

        self.cross_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )

        # Layer normalization (pre-norm style)
        # self.norm_attn = nn.LayerNorm(self.embed_dim, eps=layer_norm_eps)
        self.cross_gate = nn.Parameter(torch.zeros(1))
        # Feed-forward network that maps concat space back to single channel
        self.ffn = nn.Sequential(
            CustomLinearInitialized(self.embed_dim * 2, self.ffn_dim, init_fun=first_init_fun),
            ACT2FN[config.activation_function],
            nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1),
            CustomLinearInitialized(self.ffn_dim, self.embed_dim, init_fun=second_init_fun),
            nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, 2, T, F) - batch, channels, time, features
        Returns:
            Updated hidden states of same shape
        """
        q_channel = hidden_states[:, 0]  # (B, T, F)
        kv_channel = hidden_states[:, 1]  # (B, T, F)

        # Cross-attention
        attn_output = self.cross_attn(
            hidden_states=q_channel,
            key_value_states=kv_channel,
            output_attentions=False
        )[0]

        # Concatenate attention output with original normalized query
        q_concat = torch.cat([attn_output, q_channel], dim=-1)  # (B, T, 2*F)

        # Feed-forward processing (no normalization to preserve initialization)
        # updated_q = self.ffn(q_concat)  # (B, T, F)
        updated_q = q_channel + torch.tanh(self.cross_gate) * self.ffn(q_concat)

        # Return stacked result (only query channel is updated)
        return torch.stack([updated_q, kv_channel], dim=1)


class CrossAttentionEnrollBlock(nn.Module):
    def __init__(self, config: WhisperConfig):
        super().__init__()
        self.embed_dim = config.d_model

        # Separate attention block per speaker
        self.attn_block = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )

        self.norm = nn.LayerNorm(self.embed_dim)
        self.gate = Gate(1, 0.1)

    def forward(self, hidden_states):
        q = hidden_states[:, 0]   # (B, T, F)
        kv = hidden_states[:, 1]  # (B, T, F)
        attn_out, _, _ = self.attn_block(hidden_states=q, key_value_states=kv)  # (B, T, F)
        out = self.norm(attn_out)

        # Create updated first channel
        updated_q = self.gate(out[:, None, :, :], dim=1)[:, 0] + q

        # Concatenate along the channel dimension
        result = torch.stack([updated_q, kv], dim=1)
        return result


class CompetitiveCrossAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.num_speakers = getattr(config, "mt_num_speakers", 2)
        if self.num_speakers != 2:
            raise ValueError("CompetitiveCrossAttentionBlock supports only 2 speakers.")

        # Separate projections for Q, K, V
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.norms = nn.ModuleList([nn.LayerNorm(self.embed_dim) for _ in range(self.num_speakers)])
        self.eps = 1e-6
        self.gate = Gate(self.num_speakers, 0.01)

    def _shape(self, tensor, seq_len, batch_size):
        # reshape into (B, num_heads, T, head_dim)
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states):
        # hidden_states: (B, 2, T, F)
        B, _, T, _ = hidden_states.shape

        h1, h2 = hidden_states[:, 0], hidden_states[:, 1]  # (B, T, F)

        # Project Q, K, V
        Q1 = self.q_proj(h1)  # (B, T, F)
        K2 = self.k_proj(h2)
        V2 = self.v_proj(h2)

        Q2 = self.q_proj(h2)
        K1 = self.k_proj(h1)
        V1 = self.v_proj(h1)

        # Reshape for multi-head attention
        Q1 = self._shape(Q1, T, B)  # (B, heads, T, head_dim)
        K2 = self._shape(K2, T, B)
        V2 = self._shape(V2, T, B)

        Q2 = self._shape(Q2, T, B)
        K1 = self._shape(K1, T, B)
        V1 = self._shape(V1, T, B)

        # Scaled dot-product attention logits
        scale = 1 / (self.head_dim ** 0.5)
        L_1to2 = torch.matmul(Q1, K2.transpose(-1, -2)) * scale  # (B, heads, T, T)
        L_2to1 = torch.matmul(Q2, K1.transpose(-1, -2)) * scale  # (B, heads, T, T)

        # Softmax over last dim (keys)
        S_1to2 = F.softmax(L_1to2, dim=-1)
        S_2to1 = F.softmax(L_2to1, dim=-1)

        # Competitive normalization (soft exclusivity)
        M_joint = S_1to2 + S_2to1 + self.eps
        A_1to2 = S_1to2 / M_joint
        A_2to1 = S_2to1 / M_joint

        # Weighted sum of values
        H1_attn = torch.matmul(A_1to2, V2)  # (B, heads, T, head_dim)
        H2_attn = torch.matmul(A_2to1, V1)

        # Concatenate heads back
        H1_attn = H1_attn.transpose(1, 2).contiguous().view(B, T, self.embed_dim)  # (B, T, F)
        H2_attn = H2_attn.transpose(1, 2).contiguous().view(B, T, self.embed_dim)

        # Output projection
        H1_attn = self.norms[0](self.out_proj(H1_attn))
        H2_attn = self.norms[1](self.out_proj(H2_attn))

        # Residuals
        out = hidden_states + self.gate(torch.concat([H1_attn[:, None, :, :], H2_attn[:, None, :, :]], dim=1), dim=1)

        return out  # (B, 2, T, F)


class CoAttentionWrapper(nn.Module):
    def __init__(self, config, num_speakers=2):
        super().__init__()
        self.coa = CoAttention(embed_dim=config.d_model, single_dim=config.d_model // 2,
                               multi_dim=config.d_model // 4, n_heads=config.encoder_attention_heads,
                               attn_dropout=config.attention_dropout)
        self.gate = Gate(num_speakers, 0.01)

    def forward(self, coa_input: torch.Tensor) -> torch.Tensor:
        # coa_input: (B, 2, T, F)
        hidden_states = coa_input.permute(-2, 0, 1, -1)
        hidden_states = self.coa(hidden_states)
        out = coa_input + self.gate(hidden_states.permute(1, 2, 0, -1), dim=1)
        return out


class SpeakerCommunicationBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_speakers = getattr(config, "mt_num_speakers", 2)
        self.embed_dim = config.d_model
        self.scb_method = config.scb_method
        self.config = config

        if self.scb_method == "tac":
            self.method = TACBlock(config)
        elif self.scb_method == "cross_attention":
            self.method = CrossAttentionBlock(config)
        elif self.scb_method == "cross_attention_enroll":
            self.method = CrossAttentionEnrollBlock(config)
        elif self.scb_method == "cross_attention_enroll_new":
            self.method = CrossAttentionEnrollBlockNew(config)
        elif self.scb_method == "competitive_cross_attention":
            self.method = CompetitiveCrossAttentionBlock(config)
        elif self.scb_method == "co_attention":
            self.method = CoAttentionWrapper(config)
        elif self.scb_method == "identity":
            self.method = (
                nn.Parameter(torch.zeros(self.embed_dim)) if config.fddt_bias_only
                else CustomDiagonalLinear(self.embed_dim, bias=True, init_eye_val=1.0) if config.fddt_is_diagonal
                else CustomLinear(self.embed_dim, self.embed_dim, bias=True, init_eye_val=1.0)
            )
        else:
            raise ValueError(f"Unsupported scb_method: {self.scb_method}")

    def forward(self, x):
        # x: (B, T, F)
        B, T, F = x.shape
        S = self.num_speakers

        # Reshape to (B//S, S, T, F)
        x_reshaped = x.view(B // S, S, T, F)

        # Call the selected method
        out = self.method(x_reshaped)

        # Reshape back (B, T, F)
        out_merged = out.view(B, T, F)
        return out_merged
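
A small runnable illustration of the batch layout `SpeakerCommunicationBlock.forward` assumes: the per-speaker streams of one recording are stacked contiguously along the batch axis, so `(B, T, F)` with `B = n_recordings * S` folds into `(n_recordings, S, T, F)` before the communication method runs and is folded back afterwards.

import torch

B, S, T, F = 4, 2, 6, 8               # two recordings, two speakers each
x = torch.randn(B, T, F)
x_folded = x.view(B // S, S, T, F)    # (2, 2, 6, 8): dim 1 indexes the speaker
assert torch.equal(x_folded[0, 1], x[1])       # speaker 1 of recording 0 is batch row 1
assert torch.equal(x_folded.view(B, T, F), x)  # the inverse reshape restores the original layout
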
coattention.py ADDED
@@ -0,0 +1,120 @@
import torch
from torch import nn


class MultiHeadCoAttention(nn.Module):
    def __init__(self, multi_dim, single_dim, num_heads):
        assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
        assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
        super().__init__()
        self.q_proj = nn.Linear(single_dim, single_dim)
        self.k_proj = nn.Linear(single_dim, single_dim)
        self.multi_v_proj = nn.Linear(multi_dim, multi_dim)     # D'
        self.single_v_proj = nn.Linear(single_dim, single_dim)  # D

        self.multi_out_proj = nn.Linear(multi_dim, multi_dim)     # D'
        self.single_out_proj = nn.Linear(single_dim, single_dim)  # D

        self.multi_dim = multi_dim
        self.single_dim = single_dim
        self.num_heads = num_heads

    def forward(self, query, key, multi_value, single_value):
        # query, key: (T,B,1,D); multi_value: (T,B,ch,D')
        # single_value: (T,B,1,D)
        query = torch.transpose(query, 0, 1)  # (B,T,1,D)
        key = torch.transpose(key, 0, 1)  # (B,T,1,D)
        multi_value = torch.permute(multi_value, (1, 2, 0, 3))  # (B,ch,T,D') ... [32, 4, 150, 64]
        single_value = torch.permute(single_value, (1, 2, 0, 3))  # (B,1,T,D) ... [32, 1, 150, 256]

        q = torch.split(self.q_proj(query), self.single_dim // self.num_heads, dim=-1)  # seq: (B,T,1,D/h)
        q = torch.stack(q, dim=1)  # (B,h,T,1,D/h)

        k = torch.split(self.k_proj(key), self.single_dim // self.num_heads, dim=-1)  # seq: (B,T,1,D/h)
        k = torch.stack(k, dim=1)  # (B,h,T,1,D/h)

        multi_v = torch.split(self.multi_v_proj(multi_value), self.multi_dim // self.num_heads,
                              dim=-1)  # seq: (B,ch,T,D'/h)
        multi_v = torch.stack(multi_v, dim=1)  # (B, h, ch, T, D'/h) ... [32, 8, 4, 150, 8]

        single_v = torch.split(self.single_v_proj(single_value), self.single_dim // self.num_heads,
                               dim=-1)  # seq: (B,1,T,D/h)
        single_v = torch.stack(single_v, dim=1)  # (B,h,1,T,D/h) ... [32, 8, 1, 150, 32]

        q = q.view(*q.shape[:-2], -1)  # (B, h, T, D/h)
        k = k.view(*k.shape[:-2], -1)  # (B, h, T, D/h)
        normalizer = torch.sqrt(torch.Tensor([float(q.shape[-1])]).to(q.device))

        sim_mat = torch.matmul(q, torch.transpose(k, -2, -1)) / normalizer  # (B, h, T, T)
        att_mat = torch.unsqueeze(nn.functional.softmax(sim_mat, dim=-1), 2)  # (B, h, 1, T, T)

        # co-attention
        multi_result = torch.matmul(att_mat, multi_v)  # (B, h, ch, T, D'/h)
        single_result = torch.matmul(att_mat, single_v)  # (B, h, 1, T, D/h)

        multi_result = torch.permute(multi_result, (3, 0, 2, 1, 4))  # (T, B, ch, h, D'/h)
        single_result = torch.permute(single_result, (3, 0, 2, 1, 4))  # (T, B, 1, h, D/h)
        multi_result = torch.reshape(multi_result, multi_result.shape[:-2] + (-1,))  # (T, B, ch, D')
        single_result = torch.reshape(single_result, single_result.shape[:-2] + (-1,))  # (T, B, 1, D)

        multi_result = self.multi_out_proj(multi_result)
        single_result = self.single_out_proj(single_result)
        return multi_result, single_result


class CoAttention(nn.Module):
    def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
                 init_mult=1e-2):  # , pre_norm=True):
        super().__init__()
        self.init_mult = init_mult

        self.in_single_proj = nn.Linear(embed_dim, single_dim)  # single_dim == D
        self.in_single_ln = nn.LayerNorm(single_dim)

        self.in_multi_proj = nn.Linear(embed_dim, multi_dim)  # multi_dim == D'
        self.in_multi_ln = nn.LayerNorm(multi_dim)

        self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
        self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
        self.mca_single_out_ln = nn.LayerNorm(single_dim)

        # default MHA input: (seq, batch, feature)
        self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout, bias=True,
                                                     kdim=None, vdim=None)
        self.mha_ln = nn.LayerNorm(single_dim)

        self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)

        self.miso = False

    def scale_weights(self):
        self.cat_proj.bias.data *= 0.
        self.cat_proj.weight.data *= self.init_mult

    def forward(self, x):
        # x: (T,B,ch,F); e.g. (150, 32, 4, 768)
        frames, B, chans, feat_dim = x.shape

        single_x = torch.mean(x, dim=2)  # (T,B,F)
        single_x = self.in_single_ln(self.in_single_proj(single_x)).unsqueeze(dim=-2)  # (T,B,1,D)

        multi_x = self.in_multi_ln(self.in_multi_proj(x))  # (T,B,ch,D')

        # MCA
        multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x)  # (T,B,ch,D'), (T,B,1,D)
        single_x = single_x + single_mca
        multi_x = multi_x + multi_mca
        multi_x = self.mca_multi_out_ln(multi_x)  # (T,B,ch,D')
        single_x = torch.squeeze(self.mca_single_out_ln(single_x), -2)  # (T,B,D)

        # MHA
        single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False)  # (T, B, D)
        single_x = self.mha_ln(single_mha + single_x)

        # join representations
        single_x = single_x.unsqueeze(-2)  # (T,B,1,D)
        single_x_tile = torch.tile(single_x, (1, 1, chans, 1))  # (T,B,ch,D)
        cat_x = torch.cat([single_x_tile, multi_x], dim=-1)  # (T,B,ch,D+D')
        out = self.cat_proj(cat_x)  # (T,B,ch,F)

        return out
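
A shape-level usage sketch (assuming `coattention.py` above is importable on its own, which its imports allow): the module consumes `(T, B, ch, F)` and returns a tensor of the same shape.

import torch
from coattention import CoAttention  # assumes the file above is on the Python path

coa = CoAttention(embed_dim=768, single_dim=256, multi_dim=64, n_heads=8)
x = torch.randn(150, 4, 2, 768)  # (T, B, ch, F): frames, batch, speaker channels, features
print(coa(x).shape)              # torch.Size([150, 4, 2, 768])
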
config.json ADDED
@@ -0,0 +1,86 @@
{
  "_name_or_path": "/mnt/matylda5/ipoloka/ASRU_models/se_dicow",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "additional_layer": false,
  "additional_self_attention_layer": true,
  "apply_fddt_to_n_layers": -1,
  "apply_spec_augment": false,
  "architectures": [
    "DiCoWForConditionalGeneration"
  ],
  "attend_to_enrollment": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "config.DiCoWConfig",
    "AutoModelForSpeechSeq2Seq": "modeling_dicow.DiCoWForConditionalGeneration"
  },
  "begin_suppress_tokens": [
    220,
    50256
  ],
  "blank_token_id": null,
  "bos_token_id": 50257,
  "classifier_proj_size": 256,
  "contrastive_loss_weight": 0,
  "ctc_loss_reduction": "mean",
  "ctc_weight": 0.3,
  "ctc_zero_infinity": false,
  "d_model": 1280,
  "decoder_attention_heads": 20,
  "decoder_ffn_dim": 5120,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 20,
  "encoder_ffn_dim": 5120,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 32,
  "eos_token_id": 50257,
  "fddt_bias_only": false,
  "fddt_init": "disparagement",
  "fddt_is_diagonal": true,
  "fddt_use_non_target": true,
  "fddt_use_overlap": true,
  "fddt_use_silence": true,
  "fddt_use_target": true,
  "final_dropout": 0.0,
  "forced_decoder_ids": null,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "is_mt": true,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "max_source_positions": 1500,
  "max_target_positions": 448,
  "median_filter_width": 7,
  "model_type": "DiCoW",
  "mt_num_speakers": 2,
  "n_soft_prompts": 16,
  "non_target_fddt_value": 0.5,
  "num_hidden_layers": 32,
  "num_mel_bins": 128,
  "num_speakers": null,
  "pad_token_id": 50257,
  "remove_timestamps_from_ctc": true,
  "scale_embedding": false,
  "scb_layers": 8,
  "scb_method": "cross_attention_enroll_new",
  "sid_loss_weight": 0,
  "spk_embedding_extraction_layer": -1,
  "sub_sample": true,
  "torch_dtype": "float32",
  "transformers_version": "4.42.0",
  "use_cache": true,
  "use_enrollment_network": false,
  "use_fddt": true,
  "use_initial_fddt": true,
  "use_weighted_layer_sum": false,
  "uses_enrollments": true,
  "vocab_size": 51866
}
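
To sanity-check the DiCoW-specific settings of this checkpoint without loading any weights, the config can be read directly; the keys below all appear in the file above (the path assumes a local checkout of this repository).

import json

with open("config.json") as f:
    cfg = json.load(f)
print(cfg["model_type"], cfg["scb_method"], cfg["mt_num_speakers"], cfg["ctc_weight"])
# DiCoW cross_attention_enroll_new 2 0.3
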
config.py ADDED
@@ -0,0 +1,103 @@
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import WhisperConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput


@dataclass
class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
    enc_loss: Optional[torch.FloatTensor] = None
    dec_loss: Optional[torch.FloatTensor] = None
    encoder_logits: Optional[torch.FloatTensor] = None


@dataclass
class BaseModelOutputLogit(BaseModelOutput):
    logits: Optional[torch.FloatTensor] = None


@dataclass
class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
    encoder_logits: Optional[torch.FloatTensor] = None


class DiCoWConfig(WhisperConfig):
    """Configuration class for DiCoW, a modified version of `WhisperConfig` from the `transformers`
    library. The corresponding model has been modified to support CTC loss computation in the
    forward pass and multi-talker extensions (FDDT, speaker communication blocks)."""
    model_type = "DiCoW"

    def __init__(
        self,
        ctc_loss_reduction: str = "mean",
        final_dropout: float = 0.0,
        ctc_zero_infinity: bool = False,
        ctc_weight: float = 0.0,
        blank_token_id: Optional[int] = None,
        additional_layer: bool = False,
        additional_self_attention_layer: bool = False,
        sub_sample: bool = False,
        use_fddt: bool = True,
        fddt_is_diagonal: bool = True,
        fddt_bias_only: bool = False,
        fddt_use_silence: bool = True,
        fddt_use_target: bool = True,
        fddt_use_overlap: bool = True,
        fddt_use_non_target: bool = True,
        remove_timestamps_from_ctc: bool = False,
        apply_fddt_to_n_layers: int = -1,
        fddt_init: str = 'non-disturbing',  # random, non-disturbing, disparagement
        n_soft_prompts: int = 16,
        mt_num_speakers: int = 1,
        is_mt: bool = False,
        non_target_fddt_value: float = 0.0,
        use_initial_fddt: bool = False,
        scb_method: str = None,
        scb_layers: int = -1,
        contrastive_loss_weight: float = 0.0,
        use_enrollment_network: bool = False,
        spk_embedding_extraction_layer: int = -1,
        num_speakers: int = -1,
        sid_loss_weight: float = 0.0,
        attend_to_enrollment: bool = False,
        uses_enrollments: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_loss_reduction = ctc_loss_reduction
        self.final_dropout = final_dropout
        self.ctc_zero_infinity = ctc_zero_infinity
        self.ctc_weight = ctc_weight
        self.blank_token_id = blank_token_id
        self.additional_layer = additional_layer
        self.additional_self_attention_layer = additional_self_attention_layer
        self.sub_sample = sub_sample
        self.use_fddt = use_fddt
        self.fddt_is_diagonal = fddt_is_diagonal
        self.fddt_bias_only = fddt_bias_only
        self.fddt_use_silence = fddt_use_silence
        self.fddt_use_target = fddt_use_target
        self.fddt_use_overlap = fddt_use_overlap
        self.fddt_use_non_target = fddt_use_non_target
        self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
        self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
        self.fddt_init = fddt_init
        self.n_soft_prompts = n_soft_prompts
        self.mt_num_speakers = mt_num_speakers
        self.non_target_fddt_value = non_target_fddt_value
        self.use_initial_fddt = use_initial_fddt
        self.scb_method = scb_method
        self.scb_layers = scb_layers
        self.contrastive_loss_weight = contrastive_loss_weight
        self.is_mt = is_mt
        self.use_enrollment_network = use_enrollment_network
        self.spk_embedding_extraction_layer = spk_embedding_extraction_layer
        self.num_speakers = num_speakers
        self.sid_loss_weight = sid_loss_weight
        self.attend_to_enrollment = attend_to_enrollment
        self.uses_enrollments = uses_enrollments


_HIDDEN_STATES_START_POSITION = 2
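
A minimal construction sketch (assuming `config.py` above is importable as a top-level module, which is an assumption about the local setup): DiCoW-specific flags are explicit keyword arguments, while standard Whisper options pass through `**kwargs` to `WhisperConfig`.

from config import DiCoWConfig  # assumes config.py above is on the Python path

cfg = DiCoWConfig(ctc_weight=0.3, mt_num_speakers=2, is_mt=True, d_model=1280, encoder_layers=32)
print(cfg.model_type, cfg.ctc_weight, cfg.mt_num_speakers)  # DiCoW 0.3 2
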
contrastive_loss.py ADDED
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch import Tensor


class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=.25, distance_metric='cosine'):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.distance_metric = distance_metric

    def compute_similarity(self, embeddings):
        if self.distance_metric == 'cosine':
            embeddings = F.normalize(embeddings, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings, embeddings.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def compute_cross_similarity(self, embeddings1, embeddings2):
        """Compute similarity between two different embedding sets"""
        if self.distance_metric == 'cosine':
            embeddings1 = F.normalize(embeddings1, p=2, dim=-1)  # [B, 2T, D]
            embeddings2 = F.normalize(embeddings2, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings1, embeddings2.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def pairwise_and_no_diag(self, m):
        m_i = m.unsqueeze(2)  # [B, T, 1]
        m_j = m.unsqueeze(1)  # [B, 1, T]
        out = m_i & m_j  # [B, T, T]
        diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
        return out & ~diag

    def forward(self, embeddings, anchors, enrollment_embeddings: Optional[Tensor] = None,
                enrollment_embeddings_mask: Optional[Tensor] = None):
        """
        Args:
            embeddings: [B, 2T, D] - main embeddings
            anchors: [B, 2T] - boolean mask indicating anchor positions
            enrollment_embeddings: Optional[B, 2T, D] - enrollment embeddings for positive pairs
            enrollment_embeddings_mask: Optional[B, 2T] - boolean mask for valid enrollment positions
        Returns:
            Scalar contrastive loss
        """
        # Use enrollment embeddings if provided
        if enrollment_embeddings is not None and enrollment_embeddings_mask is not None:
            return self._forward_with_enrollment(embeddings, anchors, enrollment_embeddings, enrollment_embeddings_mask)
        else:
            # Fall back to original behavior
            return self._forward_original(embeddings, anchors)

    def _forward_with_enrollment(self, embeddings, anchors, enrollment_embeddings, enrollment_embeddings_mask):
        """Forward pass using enrollment embeddings as positives"""
        B, two_T, D = embeddings.shape
        T = two_T // 2

        # Compute similarity between main embeddings and enrollment embeddings
        cross_sim = self.compute_cross_similarity(embeddings, enrollment_embeddings)  # [B, 2T, 2T]

        # Compute similarity within main embeddings for negatives
        self_sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split anchor mask
        m1 = anchors[:, :T]  # [B, T]
        m2 = anchors[:, T:]  # [B, T]

        # Split enrollment mask
        enroll_m1 = enrollment_embeddings_mask[:, :T]  # [B, T]
        enroll_m2 = enrollment_embeddings_mask[:, T:]  # [B, T]

        # Create positive mask: anchor positions can match with corresponding enrollment positions
        # First speaker (positions 0:T) matches with enrollment first speaker (positions 0:T)
        pos_mask_1to1 = m1.unsqueeze(2) & enroll_m1.unsqueeze(1)  # [B, T, T]
        # Second speaker (positions T:2T) matches with enrollment second speaker (positions T:2T)
        pos_mask_2to2 = m2.unsqueeze(2) & enroll_m2.unsqueeze(1)  # [B, T, T]

        # Build full positive mask
        pos_mask = torch.cat([
            torch.cat([pos_mask_1to1, torch.zeros_like(pos_mask_1to1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_mask_2to2), pos_mask_2to2], dim=2)   # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Create negative mask: cross-speaker pairs within main embeddings
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),                  # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)   # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Exclude self-pairs in negative mask
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        neg_mask &= ~identity_mask

        # Also exclude self-pairs in positive mask (diagonal elements)
        pos_mask &= ~identity_mask

        # Compute contrastive loss
        if pos_mask.any():
            # Get positive similarities from cross-similarity matrix
            pos_sim = cross_sim[pos_mask]  # [num_pos_pairs]
            pos_exp = torch.exp(pos_sim)   # [num_pos_pairs]

            # Compute negative exponentials from self-similarity matrix
            exp_self_sim = torch.exp(self_sim)  # [B, 2T, 2T]
            neg_exp_sum = torch.sum(exp_self_sim * neg_mask.float(), dim=2)  # [B, 2T]

            # Get the negative sums corresponding to each positive pair
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]    # [num_pos_pairs]

            # Get negative sums for each positive pair's anchor
            neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx]  # [num_pos_pairs]

            # Compute denominators: exp(pos) + sum(exp(neg)) for each positive pair
            denominators = pos_exp + neg_sums_for_pos  # [num_pos_pairs]

            # InfoNCE loss: -log(exp(pos) / denominator)
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No positive pairs found, return zero loss
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        return total_loss

    def _forward_original(self, embeddings, pos_indicator_mask):
        """Original forward pass for backward compatibility"""
        B, two_T, D = embeddings.shape
        T = two_T // 2
        sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split input mask
        m1 = pos_indicator_mask[:, :T]  # [B, T]
        m2 = pos_indicator_mask[:, T:]  # [B, T]

        # Positive mask (same speaker pairs, diagonal excluded)
        pos_block1 = self.pairwise_and_no_diag(m1)  # [B, T, T]
        pos_block2 = self.pairwise_and_no_diag(m2)  # [B, T, T]
        pos_mask = torch.cat([
            torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2)   # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Negative mask (cross-speaker pairs where both are active)
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),                  # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)   # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Identity mask (exclude [i, i] self-pairs)
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        pos_mask &= ~identity_mask
        neg_mask &= ~identity_mask

        # Fully vectorized InfoNCE computation
        if pos_mask.any():
            # Compute exp(similarities) for numerical stability
            exp_sim = torch.exp(sim)  # [B, 2T, 2T]

            # Get positive similarities
            pos_sim = sim[pos_mask]      # [num_pos_pairs]
            pos_exp = torch.exp(pos_sim)  # [num_pos_pairs]

            # For each position, average the exponentials of its negatives (scaled by 10)
            neg_exp_avg = 10 * torch.mean(exp_sim * neg_mask.float(), dim=2)  # [B, 2T]

            # Get the negative sums corresponding to each positive pair
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]    # [num_pos_pairs]

            # Get negative sums for each positive pair's anchor
            neg_avgs_for_pos = neg_exp_avg[batch_idx, row_idx]  # [num_pos_pairs]

            # Compute denominators: exp(pos) + sum(exp(neg)) for each positive pair
            denominators = pos_exp + neg_avgs_for_pos  # [num_pos_pairs]

            # InfoNCE loss: -log(exp(pos) / denominator)
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No positive pairs found, return zero loss
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return total_loss
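
A runnable usage sketch for the original (no-enrollment) path, assuming `contrastive_loss.py` above is on the Python path: embeddings of the two speakers are concatenated along time into `(B, 2T, D)`, and the boolean mask marks the frames that act as anchors.

import torch
from contrastive_loss import ContrastiveLoss  # assumes the file above is on the Python path

B, T, D = 2, 5, 16
emb = torch.randn(B, 2 * T, D, requires_grad=True)  # speaker 1 frames followed by speaker 2 frames
anchors = torch.zeros(B, 2 * T, dtype=torch.bool)
anchors[:, :3] = True         # a few active frames for speaker 1
anchors[:, T:T + 3] = True    # and for speaker 2
loss = ContrastiveLoss(temperature=0.25)(emb, anchors)
loss.backward()
print(float(loss))
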
decoding.py ADDED
@@ -0,0 +1,397 @@
1
+ # pylint: skip-file
2
+ # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
3
+ import itertools as it
4
+ from typing import List
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import LogitsProcessor, PreTrainedTokenizer
9
+
10
+
11
+ class CTCPrefixScore(object):
12
+ """Compute CTC label sequence scores
13
+
14
+ which is based on Algorithm 2 in WATANABE et al.
15
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
16
+ but extended to efficiently compute the label probabilities for multiple
17
+ hypotheses simultaneously
18
+ See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
19
+ Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
20
+ """
21
+
22
+ def __init__(self, x, blank, eos):
23
+ self.logzero = -1e10
24
+ self.blank = blank
25
+ self.eos = eos
26
+ self.input_length = x.shape[1]
27
+ self.batch_size = x.shape[0]
28
+ self.x = x
29
+ self.device = x.device
30
+
31
+ # Preallocate `r` and `xs` tensors
32
+ # `num_labels` will be set dynamically in __call__ but preallocated with maximum capacity
33
+ self.max_num_labels = x.shape[2] # Set to a max value that can be dynamically resized
34
+ self.r = torch.full((self.batch_size, self.input_length, 2, self.max_num_labels), self.logzero,
35
+ device=self.device)
36
+ self.xs = torch.full((self.batch_size, self.input_length, self.max_num_labels), self.logzero,
37
+ device=self.device)
38
+
39
+ def initial_state(self):
40
+ """Obtain an initial CTC state."""
41
+ # Create initial CTC state tensor and use in-place operations to fill
42
+ r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device)
43
+ r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1)
44
+ s = torch.zeros((self.batch_size, 1), device=self.device)
45
+
46
+ return r, s
47
+
48
+ def _resize_tensors(self, number_of_current_samples, num_labels):
49
+ if self.r.shape[0] != number_of_current_samples:
50
+ self.r = self.r[:number_of_current_samples, ...]
51
+ self.xs = self.xs[:number_of_current_samples, ...]
52
+
53
+ if self.r.shape[3] != num_labels:
54
+ self.r = self.r[:, :, :, :num_labels].fill_(self.logzero)
55
+ self.xs = self.xs[:, :, :num_labels].fill_(self.logzero)
56
+ else:
57
+ self.r.fill_(self.logzero)
58
+ self.xs.fill_(self.logzero)
59
+
60
+ def _initialize_r(self, decoded_len):
61
+ mask = (decoded_len == 0)
62
+ self.r[mask, 0, 0, :] = self.xs[mask, 0]
63
+
64
+ def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev):
65
+ # Expand r_sum for num_labels and initialize log_phi
66
+ log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1])
67
+
68
+ # Create mask for cases where `decoded_len > 0` and to identify where `c == last[i]` for all `i`
69
+ non_zero_mask = (decoded_len > 0)
70
+ label_match_mask = (cs == last.unsqueeze(1))
71
+
72
+ # Update log_phi where both `decoded_len > 0` and `c == last[i]`
73
+ log_phi = torch.where((non_zero_mask.unsqueeze(1) & label_match_mask)[:, None, :], r_prev[..., 1:2], log_phi)
74
+ return log_phi
75
+
76
+ def _compute_log_psi(self, decoded_len, log_phi, x_current):
77
+ """This function computes forward probabilities log(r_t^n(h)), log(r_t^b(h)),
78
+ and log prefix probabilities log(psi) for all labels in the batch.
79
+
80
+ :param decoded_len: tensor of shape (batch_size,) containing the length of the decoded sequence
81
+ :param log_phi: tensor of shape (batch_size, input_length, num_labels) containing the forward probabilities
82
+ :param x_current: tensor of shape (batch_size, input_length, num_labels) containing the input frame
83
+
84
+ :return log_psi: tensor of shape (batch_size,num_labels) containing the log prefix probabilities
85
+ """
86
+ B, T, V = log_phi.shape
87
+ start = torch.clamp(decoded_len, min=1) # Ensure start is at least 1 to avoid out-of-bounds
88
+
89
+ # Initialize log_psi with the start position of r[:, start - 1, 0, :]
90
+ log_psi = self.r[torch.arange(B), start - 1, 0, :]
91
+
92
+ # Mask for handling sequence lengths based on decoded_len
93
+ mask_t = torch.arange(1, T, device=decoded_len.device).expand(B, T - 1) >= decoded_len.unsqueeze(1)
94
+
95
+ # Accumulate log_psi only up to the last valid time step for each sequence
96
+ log_psi = torch.logaddexp(log_psi, torch.logsumexp(
97
+ torch.where(mask_t.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero), dim=1))
98
+
99
+ start = torch.clamp(decoded_len, 1)
100
+
101
+ # TODO: Vectorize this loop by compute suffix xs and multiplying with log_phi
102
+ # xs = self.xs[:,1:,:].clone()
103
+ # xs_cum = torch.cumsum(xs, dim=1)
104
+ # xs_cum_expanded = xs_cum.unsqueeze(1).repeat(1, T-1, 1, 1)
105
+ # xs_u = (xs_cum_expanded - torch.nn.functional.pad(xs_cum[:,:-1,:], (0,0,1,0), value=0).unsqueeze(2).repeat(1, 1,T-1,1)).permute(0,2,1,3)
106
+ #
107
+ # phis_new = log_phi[:,:-1].clone()
108
+ # phis_new[:, 0] = torch.logaddexp(phis_new[:, 0], self.r[:, 0, 0, :])
109
+ # phis_new = phis_new.unsqueeze(1).repeat(1, T-1, 1, 1)
110
+ # causal_mask = torch.ones((T-1,T-1), dtype=torch.bool, device=self.device).tril().unsqueeze(0).unsqueeze(-1).repeat(B,1,1,1)
111
+ # mask = causal_mask & mask_t.unsqueeze(2).unsqueeze(-1)
112
+ # r_zero = torch.logsumexp(torch.where(mask, xs_u + phis_new, self.logzero), dim=2)
113
+ # self.r[:,1:,0] = r_zero
114
+
115
+ for t in range(start.min(), self.input_length):
116
+ should_decode = decoded_len <= t
117
+ self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0],
118
+ log_phi[:, t - 1]) + self.xs[:, t]
119
+ self.r[:, t, 1] = (
120
+ torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1]) + x_current[:, t, self.blank][:, None]
121
+ )
122
+ if ~should_decode.any():
123
+ self.r[:, t] = torch.where(should_decode.unsqueeze(-1).unsqueeze(-1), self.r[:, t], self.logzero)
124
+
125
+ return log_psi
126
+
127
+ def _update_log_psi_with_eos(self, log_psi, cs, r_sum):
128
+ # Update log_psi for eos positions
129
+ eos_mask = (cs == self.eos)
130
+ log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask]
131
+
132
+ # Exclude blank probabilities if eos is not the blank
133
+ if self.eos != self.blank:
134
+ blank_mask = (cs == self.blank)
135
+ log_psi[blank_mask] = self.logzero
136
+ return log_psi
137
+
138
+ def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev):
139
+ """Compute CTC prefix scores for next labels
140
+
141
+ :param y : prefix label sequence
142
+ :param cs : array of next labels
143
+ :param r_prev: previous CTC state
144
+ :return ctc_scores, ctc_states
145
+ """
146
+ # initialize CTC states
147
+ # output_length = y.shape[1] - 1 # ignore sos
148
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
149
+ # that corresponds to r_t^n(h) and r_t^b(h).
150
+
151
+ # Dynamically resize r and xs to match num_labels if necessary
152
+ num_labels = cs.shape[1]
153
+ number_of_current_samples = cs.shape[0]
154
+ self._resize_tensors(number_of_current_samples, num_labels)
155
+
156
+ # Create a view of the current input frame
157
+ x_current = self.x[samples_to_be_decoded]
158
+ self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1))
159
+
160
+ # Initialize r for the first frame
161
+ self._initialize_r(decoded_len)
162
+
163
+ # prepare forward probabilities for the last label
164
+ r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1]) # log(r_t^n(g) + r_t^b(g))
165
+ last = y[:, -1]
166
+
167
+ # precompute log_phi
168
+ log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev)
169
+
170
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
171
+ # and log prefix probabilities log(psi)
172
+ log_psi = self._compute_log_psi(decoded_len, log_phi, x_current)
173
+
174
+ # get P(...eos|X) that ends with the prefix itself
175
+ log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum)
176
+
177
+ # return the log prefix probabilities and the CTC states (self.r); the caller later
178
+ # slices the states per selected label (see CTCRescorerLogitsProcessor.update_state)
179
+ return log_psi, self.r
180
+
181
+
182
+ class CTCRescorerLogitsProcessor(LogitsProcessor):
183
+ def __init__(
184
+ self,
185
+ encoder_logits: torch.FloatTensor,
186
+ encoder_output_lens: torch.Tensor,
187
+ blank_token_id: int,
188
+ pad_token_id: int,
189
+ eos_token_id: int,
190
+ bos_token_id: int,
191
+ tokenizer: PreTrainedTokenizer,
192
+ ctc_margin: int,
193
+ ctc_weight: float,
194
+ num_beams: int,
195
+ debug: bool = False,
196
+ ctc_tokens_to_score: int = 500
197
+ ):
198
+ super().__init__()
199
+ same_logits = torch.tensor(list((tokenizer.upper_cased_tokens.items())))
200
+
201
+ logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1)
202
+ logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]]
203
+
204
+ self.logits = logits
205
+
206
+ self.ctc_prefix_scorer = CTCPrefixScore(
207
+ self.logits,
208
+ blank_token_id,
209
+ eos_token_id,
210
+ )
211
+ self.batch_size = logits.shape[0]
212
+ self.input_length = logits.shape[1]
213
+ self.num_tokens = logits.shape[2]
214
+ self.device = logits.device
215
+ self.ctc_weight = ctc_weight
216
+ self.num_beams = num_beams
217
+ self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state()
218
+ self.eos_token_id = eos_token_id
219
+ self.bos_token_id = bos_token_id
220
+ self.tokenizer = tokenizer
221
+ self.pad_token_id = pad_token_id
222
+ self.blank_token_id = blank_token_id
223
+ self.debug = False
224
+ self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"]
225
+ self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device)
226
+ self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2),
227
+ device=self.device)
228
+ self.ctc_tokens_to_score = ctc_tokens_to_score
229
+
230
+ def analyze_predictions(self,
231
+ scores, ctc_scores, next_token_scores, input_ids, k=10):
232
+ print("\n" + "#" * 100)
233
+
234
+ batch_size = input_ids.shape[0]
235
+
236
+ best_att_ids = scores.topk(k=k, dim=1)
237
+ ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero
238
+ best_ctc_ids = ctc_scores.topk(k=k, dim=1)
239
+ best_ids = next_token_scores.topk(k=k, dim=1)
240
+
241
+ decoded_prefixes = self.tokenizer.batch_decode(
242
+ input_ids, decode_with_timestamps=True, skip_special_tokens=False
243
+ )
244
+
245
+ def prepare_and_decode(best_ids_tensor):
246
+ new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long)
247
+ new_tensor[:, 0::2] = best_ids_tensor.indices
248
+ new_tensor[:, 1::2] = self.tokenizer.vocab['#']
249
+
250
+ # Flatten to (batch_size * k, 2)
251
+ flat_tensor = new_tensor.view(-1, 2)
252
+ decoded = self.tokenizer.batch_decode(
253
+ flat_tensor, decode_with_timestamps=True, skip_special_tokens=False
254
+ )
255
+ # Reshape back to (batch_size, k)
256
+ decoded = [(decoded[i * k:(i + 1) * k]) for i in range(batch_size)]
257
+ return decoded
258
+
259
+ decoded_att = prepare_and_decode(best_att_ids)
260
+ decoded_ctc = prepare_and_decode(best_ctc_ids)
261
+ decoded_next = prepare_and_decode(best_ids)
262
+
263
+ for idx in range(batch_size):
264
+ print("-" * 80)
265
+ print(f"HYPOTHESIS {idx}")
266
+ print("\nPREFIX:")
267
+ print(decoded_prefixes[idx])
268
+
269
+ def print_with_pandas(tokens, scores, title):
270
+ df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in scores]])
271
+ df.index = [f"{title}", "Score"]
272
+ print(f"\n{title}:")
273
+ print(df.to_string(index=True, header=False))
274
+
275
+ print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS")
276
+ print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS")
277
+ print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS")
278
+
279
+ print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}")
280
+ print()
281
+
282
+ print("#" * 100)
283
+
284
+ def update_state(self, best_ids, beam_idx):
285
+ mask = best_ids < self.first_timestamp_token_id
286
+ self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1),
287
+ self.tmp_ctc_states[beam_idx, best_ids],
288
+ self.ctc_state_prev[beam_idx])
289
+ self.ctc_score_prev = torch.where(mask.unsqueeze(-1),
290
+ self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1),
291
+ self.ctc_score_prev[beam_idx])
292
+
293
+ def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
294
+ input_ids = input_ids_orig.clone()
295
+
296
+ # Remove prefix from CTC scoring
297
+ if (input_ids[:, 0] != self.bos_token_id).any():
298
+ input_ids = torch.stack(
299
+ [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids])
300
+
301
+ # Remove task/lang/timestamp tokens from input_ids
302
+ input_prefix_len = len(self.tokenizer.prefix_tokens)
303
+ if input_prefix_len > 1:
304
+ input_ids = input_ids[:, input_prefix_len - 1:]
305
+
306
+ # Set up the first token to be the blank token (sos)
307
+ input_ids[:, 0] = self.blank_token_id
308
+
309
+ # If the last token in input_ids is a timestamp, replace it with the last non-timestamp token (which could even be the first token)
310
+ decoded_len = torch.logical_and(input_ids <= self.first_timestamp_token_id,
311
+ input_ids != self.blank_token_id).sum(dim=1)
312
+ mask = torch.logical_and(input_ids[:, -1] >= self.first_timestamp_token_id,
313
+ input_ids[:, -1] != self.blank_token_id)
314
+ last_non_timestamp_token = torch.gather(input_ids, 1,
315
+ torch.logical_or(input_ids < self.first_timestamp_token_id,
316
+ input_ids == self.blank_token_id).sum(dim=1,
317
+ keepdim=True) - 1)
318
+ input_ids[mask, -1] = last_non_timestamp_token[mask, 0]
319
+
320
+ # If there is no eos token in the last position, we need to continue decoding
321
+ to_be_decoded = input_ids[:, -1] != self.eos_token_id
322
+ self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero
323
+
324
+ input_ids_local = input_ids[to_be_decoded]
325
+ ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=self.ctc_tokens_to_score).indices
326
+
327
+ # Always score the EOS token; if it is not among the top-k candidates, place it at the last position
328
+ is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1)
329
+ ids_to_score[~is_eos_present, self.ctc_tokens_to_score - 1] = self.eos_token_id
330
+
331
+ decoded_len_local = decoded_len[to_be_decoded]
332
+
333
+ ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded],
334
+ decoded_len_local, to_be_decoded,
335
+ self.ctc_state_prev[to_be_decoded])
336
+
337
+ # As the CTC scorer might run on a subset of samples, we need to scatter the results back to the original batch
338
+ self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded]
339
+ .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local))
340
+ self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1)
341
+ .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1)
342
+ .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local)
343
+ .permute(0, 3, 1, 2))
344
+
345
+ # Set the CTC score for the timestamp tokens to the maximum to prefer them over the rest
346
+ self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None]
347
+ ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev
348
+
349
+ next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
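+ # Both terms are log-scores (ctc_scores holds prefix-probability increments), so this is a
+ # log-linear interpolation of the attention decoder and the CTC prefix scorer; `scores` is
+ # assumed to already be log-normalized (e.g. via LogSoftmaxProcessor below).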
350
+
351
+ if self.debug:
352
+ self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig)
353
+
354
+ return next_token_scores
355
+
356
+
357
+ class LogSoftmaxProcessor(LogitsProcessor):
358
+ def __init__(
359
+ self,
360
+ ):
361
+ super().__init__()
362
+
363
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
364
+ scores = torch.nn.functional.log_softmax(scores, dim=-1)
365
+ return scores
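+ # A minimal usage sketch (assumption): run this processor before CTCRescorerLogitsProcessor
+ # so that the `scores` arriving there are log-probabilities and therefore on the same scale
+ # as the CTC prefix scores being mixed in.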
366
+
367
+
368
+ class GreedyCTCDecoder(torch.nn.Module):
369
+ def __init__(self, tokenizer, blank=0):
370
+ super().__init__()
371
+ self.blank = blank
372
+ self.tokenizer = tokenizer
373
+
374
+ def forward(self, emission: torch.Tensor) -> torch.Tensor:
375
+ """Given a sequence emission over labels, get the best path
376
+ Args:
377
+ emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
378
+
379
+ Returns:
380
+ torch.Tensor: Padded tensor of decoded token indices (best path with repeats and blanks removed)
381
+ """
382
+ indices = torch.argmax(emission, dim=-1) # [num_seq,]
383
+ indices = [torch.unique_consecutive(index, dim=-1) for index in indices]
384
+ indices = [index[index != self.blank] for index in indices]
385
+ indices = torch.nn.utils.rnn.pad_sequence(indices, batch_first=True,
386
+ padding_value=self.tokenizer.pad_token_id)
387
+ indices[indices >= len(self.tokenizer)] = self.tokenizer.unk_token_id
388
+ return indices
389
+
390
+
391
+ def ctc_greedy_decode(logits: torch.Tensor, blank, pad_token_id) -> torch.Tensor:
392
+ idxs = torch.argmax(logits, dim=-1)
393
+ for i, prediction in enumerate(idxs):
394
+ deduplicated = [k for k, g in it.groupby(prediction) if k != blank]
395
+ idxs[i, : len(deduplicated)] = torch.tensor(deduplicated)
396
+ idxs[i, len(deduplicated):] = pad_token_id
397
+ return idxs
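+ # Worked example (hypothetical values, blank=0, pad_token_id=-1):
+ # argmax ids: [3, 3, 0, 3, 2, 2, 0]
+ # after groupby de-duplication and blank removal: [3, 3, 2]
+ # returned row: [3, 3, 2, -1, -1, -1, -1]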
encoder.py ADDED
@@ -0,0 +1,328 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
4
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer, WHISPER_ATTENTION_CLASSES
5
+
6
+ from .FDDT import FDDT
7
+ from .config import DiCoWConfig
8
+ from .SCBs import SpeakerCommunicationBlock
9
+
10
+
11
+ class DiCoWEncoder(WhisperEncoder):
12
+ config_class = DiCoWConfig
13
+
14
+ def __init__(self, config: DiCoWConfig):
15
+ super().__init__(config)
16
+ self.ctc_weight = config.ctc_weight
17
+ if config.additional_layer and self.ctc_weight > 0.0:
18
+ self.additional_layer = WhisperEncoderLayer(config)
19
+ if config.additional_self_attention_layer and self.ctc_weight > 0.0:
20
+ self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
21
+ embed_dim=config.d_model,
22
+ num_heads=config.encoder_attention_heads,
23
+ dropout=config.attention_dropout,
24
+ config=config,
25
+ )
26
+ if config.sub_sample and self.ctc_weight > 0.0:
27
+ self.subsample_conv1 = nn.Conv1d(
28
+ in_channels=config.d_model,
29
+ out_channels=config.d_model,
30
+ kernel_size=3,
31
+ stride=2,
32
+ padding=1,
33
+ bias=False,
34
+ )
35
+ self.subsample_conv2 = nn.Conv1d(
36
+ in_channels=config.d_model,
37
+ out_channels=config.d_model,
38
+ kernel_size=3,
39
+ stride=2,
40
+ padding=1,
41
+ bias=False,
42
+ )
43
+ if self.ctc_weight > 0.0:
44
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
45
+ self.final_dropout = nn.Dropout(config.final_dropout)
46
+ if config.use_fddt:
47
+ num_fddts = self.config.apply_fddt_to_n_layers if self.config.apply_fddt_to_n_layers != -1 else len(
48
+ self.layers)
49
+ self.initial_fddt = FDDT(config,
50
+ d_model=config.d_model,
51
+ non_target_rate=config.non_target_fddt_value,
52
+ is_diagonal=config.fddt_is_diagonal,
53
+ bias_only=config.fddt_bias_only,
54
+ use_silence=config.fddt_use_silence,
55
+ use_target=config.fddt_use_target,
56
+ use_overlap=config.fddt_use_overlap,
57
+ use_non_target=config.fddt_use_non_target,
58
+ use_interaction=False,
59
+ )
60
+ num_scbs = (self.config.scb_layers if self.config.scb_layers != -1 else len(
61
+ self.layers)) if self.config.is_mt else 0
62
+ self.fddts = nn.ModuleList([
63
+ FDDT(config,
64
+ d_model=config.d_model,
65
+ non_target_rate=1.0,
66
+ is_diagonal=config.fddt_is_diagonal,
67
+ bias_only=config.fddt_bias_only,
68
+ use_silence=config.fddt_use_silence,
69
+ use_target=config.fddt_use_target,
70
+ use_overlap=config.fddt_use_overlap,
71
+ use_non_target=config.fddt_use_non_target,
72
+ use_interaction=i < num_scbs,
73
+ )
74
+ for i in range(num_fddts)
75
+ ])
76
+ self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6  # 30 s of timestamps at 50 Hz (1500 tokens), -1 to reach <|0.00|>, and -6 for the task tokens
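+ # e.g. with a hypothetical vocab_size of 51866 this evaluates to 51866 - 1500 - 1 - 6 = 50359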
77
+ self.post_init()
78
+
79
+ def encode_enrollment(
80
+ self,
81
+ input_features,
82
+ num_layers_to_apply,
83
+ head_mask=None,
84
+ stno_mask=None,
85
+ ):
86
+ # For MT-ASR the input has shape (B X S) x F x T
87
+ # we can use torch.view(B, S, F, -1) to obtain
88
+ # new tensor with speaker dim
89
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
90
+ if input_features.shape[-1] != expected_seq_length:
91
+ if input_features.shape[-1] > expected_seq_length:
92
+ return CausalLMOutput(
93
+ logits=None,
94
+ hidden_states=None,
95
+ attentions=None,
96
+ )
97
+ else:
98
+ raise ValueError(
99
+ f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
100
+ )
101
+
102
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
103
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
104
+
105
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
106
+ embed_pos = self.embed_positions.weight
107
+
108
+ if self.config.use_fddt:
109
+ inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)
110
+
111
+ hidden_states = inputs_embeds + embed_pos
112
+
113
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
114
+
115
+ # check if head_mask has a correct number of layers specified if desired
116
+ if head_mask is not None:
117
+ assert head_mask.size()[0] == (
118
+ len(self.layers)
119
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
120
+
121
+ for idx, encoder_layer in enumerate(self.layers[:num_layers_to_apply]):
122
+ if self.config.use_fddt and idx < len(self.fddts):
123
+ hidden_states = self.fddts[idx](hidden_states, stno_mask)
124
+ if self.gradient_checkpointing and self.training:
125
+ layer_outputs = self._gradient_checkpointing_func(
126
+ encoder_layer.__call__,
127
+ hidden_states,
128
+ None,
129
+ (head_mask[idx] if head_mask is not None else None),
130
+ )
131
+ else:
132
+ layer_outputs = encoder_layer(
133
+ hidden_states,
134
+ None,
135
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
136
+ )
137
+
138
+ hidden_states = layer_outputs[0]
139
+
140
+ return hidden_states
141
+
142
+ @classmethod
143
+ def _load_pretrained_model(
144
+ cls,
145
+ model,
146
+ state_dict,
147
+ loaded_keys,
148
+ resolved_archive_file,
149
+ pretrained_model_name_or_path,
150
+ **kwargs
151
+ ):
152
+ for key in list(state_dict.keys()):
153
+ if key.startswith("encoder."):
154
+ state_dict[key[8:]] = state_dict.pop(key)
155
+ loaded_keys.remove(key)
156
+ loaded_keys.append(key[8:])
157
+ output = super()._load_pretrained_model(
158
+ model,
159
+ state_dict,
160
+ loaded_keys,
161
+ resolved_archive_file,
162
+ pretrained_model_name_or_path,
163
+ **kwargs
164
+ )
165
+ return output
166
+
167
+ def get_loss(self, logits, labels):
168
+ if labels.max() >= self.config.vocab_size:
169
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
170
+ if self.config.remove_timestamps_from_ctc:
171
+ labels = torch.nn.utils.rnn.pad_sequence([label[label < self.first_task_token] for label in labels],
172
+ padding_value=-100).T
173
+ input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1],
174
+ device=logits.device)
175
+
176
+ # assuming that padded tokens are filled with -100
177
+ # when not being attended to
178
+ labels_mask = labels >= 0
179
+ target_lengths = labels_mask.sum(-1)
180
+ # flattened_targets = labels_enc.masked_select(labels_mask)
181
+
182
+ # ctc_loss doesn't support fp16
183
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
184
+
185
+ with torch.backends.cudnn.flags(enabled=True):
186
+ ctc_loss = nn.functional.ctc_loss(
187
+ log_probs,
188
+ labels,
189
+ input_lengths,
190
+ target_lengths,
191
+ blank=logits.shape[-1] - 1,
192
+ reduction=self.config.ctc_loss_reduction,
193
+ zero_infinity=True,
194
+ )
195
+ return ctc_loss
196
+
197
+ def forward(
198
+ self,
199
+ input_features,
200
+ attention_mask=None,
201
+ head_mask=None,
202
+ output_attentions=None,
203
+ output_hidden_states=None,
204
+ return_dict=None,
205
+ stno_mask=None,
206
+ per_group_sizes=None
207
+ ):
208
+ # For MT-ASR the input has shape (B X S) x F x T
209
+ # we can use torch.view(B, S, F, -1) to obtain
210
+ # new tensor with speaker dim
211
+ expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
212
+ if input_features.shape[-1] != expected_seq_length:
213
+ if input_features.shape[-1] > expected_seq_length:
214
+ return CausalLMOutput(
215
+ logits=None,
216
+ hidden_states=None,
217
+ attentions=None,
218
+ )
219
+ else:
220
+ raise ValueError(
221
+ f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
222
+ )
223
+
224
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
225
+ output_hidden_states = (
226
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
227
+ )
228
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
229
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
230
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
231
+
232
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
233
+ embed_pos = self.embed_positions.weight
234
+
235
+ if self.config.use_fddt:
236
+ inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask)
237
+
238
+ hidden_states = inputs_embeds + embed_pos
239
+
240
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
241
+
242
+ encoder_states = () if output_hidden_states else None
243
+ all_attentions = () if output_attentions else None
244
+
245
+ # check if head_mask has a correct number of layers specified if desired
246
+ if head_mask is not None:
247
+ assert head_mask.size()[0] == (
248
+ len(self.layers)
249
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
250
+
251
+ for idx, encoder_layer in enumerate(self.layers):
252
+ if output_hidden_states:
253
+ encoder_states = encoder_states + (hidden_states,)
254
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
255
+ to_drop = False
256
+ if self.training:
257
+ dropout_probability = torch.rand([])
258
+ if dropout_probability < self.layerdrop: # skip the layer
259
+ to_drop = True
260
+
261
+ if self.config.use_fddt and idx < len(self.fddts):
262
+ hidden_states = self.fddts[idx](hidden_states, stno_mask)
263
+
264
+ if to_drop:
265
+ layer_outputs = (None, None)
266
+ else:
267
+ if self.gradient_checkpointing and self.training:
268
+ layer_outputs = self._gradient_checkpointing_func(
269
+ encoder_layer.__call__,
270
+ hidden_states,
271
+ None,
272
+ (head_mask[idx] if head_mask is not None else None),
273
+ output_attentions,
274
+ )
275
+ else:
276
+ layer_outputs = encoder_layer(
277
+ hidden_states,
278
+ None,
279
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
280
+ output_attentions=output_attentions,
281
+ )
282
+
283
+ hidden_states = layer_outputs[0]
284
+
285
+ if output_attentions:
286
+ all_attentions = all_attentions + (layer_outputs[1],)
287
+
288
+ hidden_states = self.layer_norm(hidden_states)
289
+ if output_hidden_states:
290
+ encoder_states = encoder_states + (hidden_states,)
291
+
292
+ if not return_dict:
293
+ outputs = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
294
+ else:
295
+ outputs = BaseModelOutput(
296
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
297
+ )
298
+
299
+ if hasattr(self, "additional_layer"):
300
+ inter_output, = self.additional_layer(
301
+ outputs.last_hidden_state,
302
+ attention_mask=None,
303
+ output_attentions=output_attentions,
304
+ layer_head_mask=None,
305
+ )
306
+ elif hasattr(self, "additional_self_attention_layer"):
307
+ inter_output, _, __ = self.additional_self_attention_layer(
308
+ outputs.last_hidden_state,
309
+ attention_mask=None,
310
+ output_attentions=output_attentions,
311
+ layer_head_mask=None,
312
+ )
313
+ else:
314
+ inter_output = outputs.last_hidden_state
315
+
316
+ inter_output = self.final_dropout(inter_output)
317
+ if hasattr(self, "subsample_conv2"):
318
+ inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)
319
+ if self.ctc_weight > 0.0:
320
+ logits = self.lm_head(inter_output)
321
+ else:
322
+ logits = None
323
+
324
+ return CausalLMOutput(
325
+ logits=logits,
326
+ hidden_states=outputs.hidden_states,
327
+ attentions=outputs.attentions,
328
+ )
generation.py ADDED
@@ -0,0 +1,1808 @@
1
+ import copy
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+ from typing import Iterator
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from decimal import Decimal, ROUND_HALF_UP
14
+
15
+ from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor, \
16
+ SuppressTokensAtBeginLogitsProcessor
17
+ from transformers.generation.configuration_utils import GenerationConfig
18
+ from transformers.generation.configuration_utils import GenerationMode
19
+ from transformers.generation.logits_process import (
20
+ LogitsProcessorList,
21
+ SuppressTokensAtBeginLogitsProcessor,
22
+ SuppressTokensLogitsProcessor, )
23
+ from transformers.generation.logits_process import WhisperNoSpeechDetection
24
+ from transformers.generation.stopping_criteria import (
25
+ StoppingCriteriaList,
26
+ )
27
+ from transformers.generation.utils import GenerateBeamOutput, BeamScorer, GenerateBeamDecoderOnlyOutput, \
28
+ stack_model_outputs, GenerateBeamEncoderDecoderOutput, _split_model_inputs, GenerateNonBeamOutput, \
29
+ GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
30
+ from transformers.modeling_outputs import BaseModelOutput
31
+ from transformers.models.whisper.modeling_whisper import (
32
+ WhisperForConditionalGeneration,
33
+ )
34
+ from transformers.models.whisper.generation_whisper import _get_attr_from_logit_processors, _pad_to_max_length
35
+ from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
36
+ from transformers.utils import logging
37
+
38
+ from .utils import WhisperTimeStampLogitsProcessorCustom
39
+ from .decoding import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
40
+
41
+ logging.set_verbosity_debug()
42
+ logger = logging.get_logger("transformers")
43
+
44
+
45
+ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
46
+ def _prepare_encoder_decoder_kwargs_for_generation(
47
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
48
+ ) -> Dict[str, Any]:
49
+ # self.encoder_output_lens = self._get_feat_extract_output_lengths(
50
+ # model_kwargs['attention_mask_enc'].sum(dim=1)
51
+ # ).int()
52
+ generation_config.output_hidden_states = True
53
+
54
+ # pylint: disable=no-member
55
+ model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
56
+ inputs_tensor, model_kwargs, model_input_name, generation_config
57
+ )
58
+ if "is_valid" in model_kwargs:
59
+ for key in ['decoder_input_ids', 'stno_mask', 'labels', 'upp_labels', 'attention_mask', 'attention_mask_enc']:
60
+ if key in model_kwargs:
61
+ model_kwargs[key] = model_kwargs[key][model_kwargs['is_valid']]
62
+ model_kwargs['encoder_outputs']['logits'] = model_kwargs['encoder_outputs']['logits'][model_kwargs['is_valid']]
63
+ hidden_states = []
64
+ for layer in range(len(model_kwargs['encoder_outputs']['hidden_states'])):
65
+ hidden_states.append(model_kwargs['encoder_outputs']['hidden_states'][layer][model_kwargs['is_valid']])
66
+ model_kwargs['encoder_outputs']['hidden_states'] = tuple(hidden_states)
67
+ model_kwargs.pop("is_valid")
68
+ self.encoder_logits = model_kwargs["encoder_outputs"].logits
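+ # The encoder CTC logits are cached on the instance; they appear to be consumed later
+ # when the CTC rescoring logits processor is set up (see CTCRescorerLogitsProcessor).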
69
+
70
+ return model_kwargs
71
+
72
+ def _prepare_decoder_input_ids_for_generation(
73
+ self,
74
+ batch_size: int,
75
+ model_input_name: str,
76
+ model_kwargs: Dict[str, torch.Tensor],
77
+ decoder_start_token_id: torch.Tensor,
78
+ device: torch.device = None,
79
+ ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
80
+ batch_size = model_kwargs['decoder_input_ids'].shape[0]
81
+ out = super()._prepare_decoder_input_ids_for_generation(
82
+ batch_size,
83
+ model_input_name,
84
+ model_kwargs,
85
+ decoder_start_token_id,
86
+ device,
87
+ )
88
+ return out
89
+
90
+ @staticmethod
91
+ def _expand_inputs_for_generation(
92
+ expand_size: int = 1,
93
+ is_encoder_decoder: bool = False,
94
+ input_ids: Optional[torch.LongTensor] = None,
95
+ **model_kwargs,
96
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
97
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
98
+
99
+ def _expand_dict_for_generation(dict_to_expand):
100
+ for key in dict_to_expand:
101
+ if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key != "loss":
102
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
103
+ return dict_to_expand
104
+
105
+ if input_ids is not None:
106
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
107
+
108
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
109
+
110
+ if is_encoder_decoder:
111
+ if model_kwargs.get("encoder_outputs") is None:
112
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
113
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
114
+ if "hidden_states" in model_kwargs["encoder_outputs"]:
115
+ model_kwargs["encoder_outputs"]["hidden_states"] = tuple(
116
+ hidden_state.repeat_interleave(expand_size, dim=0) for hidden_state in
117
+ model_kwargs["encoder_outputs"]["hidden_states"]
118
+ )
119
+
120
+ return input_ids, model_kwargs
121
+
122
+ def generate(
123
+ self,
124
+ input_features: Optional[torch.Tensor] = None,
125
+ generation_config: Optional[GenerationConfig] = None,
126
+ logits_processor: Optional[LogitsProcessorList] = None,
127
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
128
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
129
+ synced_gpus: bool = False,
130
+ return_timestamps: Optional[bool] = None,
131
+ task: Optional[str] = None,
132
+ language: Optional[str] = None,
133
+ is_multilingual: Optional[bool] = None,
134
+ prompt_ids: Optional[torch.Tensor] = None,
135
+ prompt_condition_type: Optional[str] = None, # first-segment, all-segments
136
+ condition_on_prev_tokens: Optional[bool] = None,
137
+ temperature: Optional[Union[float, Tuple[float, ...]]] = None,
138
+ compression_ratio_threshold: Optional[float] = None,
139
+ logprob_threshold: Optional[float] = None,
140
+ no_speech_threshold: Optional[float] = None,
141
+ num_segment_frames: Optional[int] = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ time_precision: float = 0.02,
144
+ return_token_timestamps: Optional[bool] = None,
145
+ return_segments: bool = False,
146
+ return_dict_in_generate: Optional[bool] = None,
147
+ assistant_model: Optional["PreTrainedModel"] = None,
148
+ **kwargs,
149
+ ):
150
+ if condition_on_prev_tokens:
151
+ raise NotImplementedError("Current version does not support conditioning")
152
+
153
+ gen_c, _ = self._prepare_generation_config(generation_config, **kwargs)
154
+ gen_mode = gen_c.get_generation_mode(assistant_model)
155
+
156
+ if gen_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.BEAM_SEARCH]:
157
+ raise ValueError(
158
+ f"Provided generation mode {gen_mode} is not supported"
159
+ f" for WhisperForConditionalGeneration with joint CTC decoding")
160
+
161
+ if "stno_mask" in kwargs:
162
+ self.stno_mask = kwargs["stno_mask"]
163
+ if "encoder_outputs" in kwargs:
164
+ self.encoder_logits = kwargs["encoder_outputs"].logits
165
+ # pylint: disable=no-member
166
+ # 0. deprecate old inputs
167
+ if "inputs" in kwargs:
168
+ input_features = kwargs.pop("inputs")
169
+ warnings.warn(
170
+ "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
171
+ FutureWarning,
172
+ )
173
+
174
+ # 1. prepare generation config
175
+ generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
176
+
177
+ # 2. set global generate variables
178
+ input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
179
+ num_segment_frames = input_stride * self.config.max_source_positions
180
+ batch_size, total_input_frames = self._retrieve_total_input_frames(
181
+ input_features=input_features, input_stride=input_stride, kwargs=kwargs
182
+ )
183
+ is_shortform = total_input_frames <= num_segment_frames
184
+
185
+ if is_shortform:
186
+ # warn user of ignored inputs
187
+ self._maybe_warn_unused_inputs(
188
+ condition_on_prev_tokens=condition_on_prev_tokens,
189
+ temperature=temperature,
190
+ compression_ratio_threshold=compression_ratio_threshold,
191
+ logprob_threshold=logprob_threshold,
192
+ no_speech_threshold=no_speech_threshold,
193
+ total_input_frames=total_input_frames,
194
+ )
195
+
196
+ # 3. Make sure generation config is correctly set
197
+ # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
198
+ self._set_return_outputs(
199
+ return_dict_in_generate=return_dict_in_generate,
200
+ return_token_timestamps=return_token_timestamps,
201
+ is_shortform=is_shortform,
202
+ logprob_threshold=logprob_threshold,
203
+ generation_config=generation_config,
204
+ )
205
+ self._set_return_timestamps(
206
+ return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
207
+ )
208
+ self._set_language_and_task(
209
+ language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
210
+ )
211
+ self._set_num_frames(
212
+ return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
213
+ )
214
+ self._set_thresholds_and_condition(
215
+ generation_config=generation_config,
216
+ logprob_threshold=logprob_threshold,
217
+ compression_ratio_threshold=compression_ratio_threshold,
218
+ no_speech_threshold=no_speech_threshold,
219
+ condition_on_prev_tokens=condition_on_prev_tokens,
220
+ )
221
+ self._set_prompt_condition_type(
222
+ generation_config=generation_config,
223
+ prompt_condition_type=prompt_condition_type,
224
+ )
225
+
226
+ # pass self.config for backward compatibility
227
+ init_tokens = self._retrieve_init_tokens(
228
+ input_features,
229
+ batch_size=batch_size,
230
+ generation_config=generation_config,
231
+ config=self.config,
232
+ num_segment_frames=num_segment_frames,
233
+ kwargs=kwargs,
234
+ )
235
+ # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
236
+ # where the input ids are handled explicitly by the generate method
237
+ self._check_decoder_input_ids(kwargs=kwargs)
238
+
239
+ # 3. Retrieve logits processors
240
+ device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
241
+ begin_index = init_tokens.shape[1]
242
+ logits_processor = self._retrieve_logit_processors(
243
+ generation_config=generation_config,
244
+ logits_processor=logits_processor,
245
+ begin_index=begin_index, # begin index is index of first generated decoder token
246
+ is_shortform=is_shortform,
247
+ num_beams=kwargs.get("num_beams", 1),
248
+ device=device,
249
+ )
250
+
251
+ # 5. If we're in shortform mode, simply generate the whole input at once and return the output
252
+ if is_shortform:
253
+ if temperature is not None:
254
+ generation_config.temperature = temperature
255
+
256
+ decoder_input_ids = kwargs.pop("decoder_input_ids", None)
257
+ if decoder_input_ids is None:
258
+ decoder_input_ids = init_tokens
259
+
260
+ if prompt_ids is not None:
261
+ decoder_input_ids = torch.cat(
262
+ [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
263
+ )
264
+
265
+ max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
266
+ if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
267
+ raise ValueError(
268
+ f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
269
+ f"is {max_new_tokens}. Thus, the combined length of "
270
+ f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
271
+ f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
272
+ "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
273
+ f"so that their combined length is less than {self.config.max_target_positions}."
274
+ )
275
+
276
+ outputs = super().generate(
277
+ input_features,
278
+ generation_config=generation_config,
279
+ logits_processor=logits_processor,
280
+ stopping_criteria=stopping_criteria,
281
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
282
+ synced_gpus=synced_gpus,
283
+ decoder_input_ids=decoder_input_ids,
284
+ **kwargs,
285
+ )
286
+
287
+ if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"):
288
+ outputs["token_timestamps"] = self._extract_token_timestamps(
289
+ outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames
290
+ )
291
+
292
+ # print("\n".join(self.tokenizer.batch_decode(outputs,skip_special_tokens=True, decode_with_timestamps=True)))
293
+ return outputs
294
+
295
+ # 6. Else we're in longform mode which is more complex.
296
+ # We need to chunk the audio input depending on when the model generates timestamp tokens
297
+
298
+ # 6.1 Set and retrieve global longform generation variables
299
+ self._set_condition_on_prev_tokens(
300
+ condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
301
+ )
302
+
303
+ timestamp_begin = generation_config.no_timestamps_token_id + 1
304
+ temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
305
+ temperature = temperatures[0]
306
+ batch_size = input_features.shape[0]
307
+
308
+ max_frames, seek = self._retrieve_max_frames_and_seek(
309
+ batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames
310
+ )
311
+
312
+ # 6.2 Prepare running variables and lists for generation
313
+ cur_bsz = batch_size
314
+ current_segments = self._prepare_segments(
315
+ prompt_ids=prompt_ids,
316
+ batch_size=batch_size,
317
+ generation_config=generation_config,
318
+ )
319
+
320
+ batch_idx_map = list(range(batch_size))
321
+ do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)]
322
+
323
+ # 6.2 Transcribe audio until we reach the end of all input audios
324
+ while (seek < max_frames).any():
325
+ # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
326
+ # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
327
+ # to know which original audio is being decoded
328
+ # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
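+ # (illustration: if the second of three audios finishes first, batch_idx_map shrinks
+ # from [0, 1, 2] to [0, 2] and cur_bsz drops from 3 to 2)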
329
+ input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
330
+ input_features=input_features,
331
+ seek=seek,
332
+ max_frames=max_frames,
333
+ cur_bsz=cur_bsz,
334
+ batch_idx_map=batch_idx_map,
335
+ )
336
+ time_offset = seek * time_precision / input_stride
337
+ seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
338
+
339
+ # 6.4 cut out next 30s segment from input features
340
+ segment_input = self._get_input_segment(
341
+ input_features=input_features,
342
+ seek=seek,
343
+ seek_num_frames=seek_num_frames,
344
+ num_segment_frames=num_segment_frames,
345
+ cur_bsz=cur_bsz,
346
+ batch_idx_map=batch_idx_map,
347
+ )
348
+
349
+ # 6.5 prepare decoder input ids
350
+ suppress_tokens = _get_attr_from_logit_processors(
351
+ logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
352
+ )
353
+ decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
354
+ cur_bsz=cur_bsz,
355
+ init_tokens=init_tokens,
356
+ current_segments=current_segments,
357
+ batch_idx_map=batch_idx_map,
358
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
359
+ prompt_ids=prompt_ids,
360
+ generation_config=generation_config,
361
+ config=self.config,
362
+ device=segment_input.device,
363
+ suppress_tokens=suppress_tokens,
364
+ kwargs=kwargs,
365
+ )
366
+
367
+ # 6.6 set max new tokens or max length
368
+ self._set_max_new_tokens_and_length(
369
+ config=self.config,
370
+ decoder_input_ids=decoder_input_ids,
371
+ generation_config=generation_config,
372
+ )
373
+
374
+ # 6.7 Set current `begin_index` for all logit processors
375
+ for proc in logits_processor:
376
+ if hasattr(proc, "set_begin_index"):
377
+ proc.set_begin_index(decoder_input_ids.shape[-1])
378
+
379
+ # 6.8 Run generate with fallback
380
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
381
+ segment_input=segment_input,
382
+ decoder_input_ids=decoder_input_ids,
383
+ cur_bsz=cur_bsz,
384
+ batch_idx_map=batch_idx_map,
385
+ seek=seek,
386
+ num_segment_frames=num_segment_frames,
387
+ max_frames=max_frames,
388
+ temperatures=temperatures,
389
+ generation_config=generation_config,
390
+ logits_processor=logits_processor,
391
+ stopping_criteria=stopping_criteria,
392
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
393
+ synced_gpus=synced_gpus,
394
+ return_token_timestamps=return_token_timestamps,
395
+ do_condition_on_prev_tokens=do_condition_on_prev_tokens,
396
+ kwargs=kwargs,
397
+ )
398
+
399
+ # 6.9 In every generated sequence, split by timestamp tokens and extract segments
400
+ if not self.config.is_mt or self.config.mt_num_speakers == 1:
401
+ for i, seek_sequence in enumerate(seek_sequences):
402
+ prev_i = batch_idx_map[i]
403
+
404
+ if should_skip[i]:
405
+ seek[prev_i] += seek_num_frames[prev_i]
406
+ continue
407
+
408
+ segments, segment_offset = self._retrieve_segment(
409
+ seek_sequence=seek_sequence,
410
+ seek_outputs=seek_outputs,
411
+ time_offset=time_offset,
412
+ timestamp_begin=timestamp_begin,
413
+ seek_num_frames=seek_num_frames,
414
+ time_precision=time_precision,
415
+ input_stride=input_stride,
416
+ prev_idx=prev_i,
417
+ idx=i,
418
+ return_token_timestamps=return_token_timestamps,
419
+ )
420
+
421
+ current_segments[prev_i] += segments
422
+ seek[prev_i] += segment_offset
423
+ else:
424
+ # We have to keep all speakers synchronized, so we find a common seek across the per-speaker instances
425
+ for j, seek_seqs in enumerate(
426
+ [seek_sequences[i * self.config.mt_num_speakers:(i + 1) * self.config.mt_num_speakers] for i in
427
+ range(len(seek_sequences) // self.config.mt_num_speakers)]):
428
+ indexes = [j * self.config.mt_num_speakers + i for i in range(self.config.mt_num_speakers)]
429
+ prev_ids = [batch_idx_map[i] for i in indexes]
430
+
431
+ if all([should_skip[i] for i in indexes]):
432
+ for i, prev_i in zip(indexes, prev_ids):
433
+ seek[prev_i] += seek_num_frames[prev_i]
434
+ continue
435
+
436
+ segments, segment_offset = self._retrieve_segment_mt(
437
+ seek_sequences=seek_seqs,
438
+ seek_outputs=seek_outputs,
439
+ time_offset=time_offset,
440
+ timestamp_begin=timestamp_begin,
441
+ seek_num_frames=seek_num_frames,
442
+ time_precision=time_precision,
443
+ input_stride=input_stride,
444
+ prev_ids=prev_ids,
445
+ ids=indexes,
446
+ return_token_timestamps=return_token_timestamps,
447
+ )
448
+ if self.config.uses_enrollments:
449
+ segment_offset[1:] = [torch.tensor(0)] * len(segment_offset[1:])
450
+ else:
451
+ segment_offset[1:] = [segment_offset[0]] * len(segment_offset[1:])
452
+
453
+ for prev_i, i in zip(prev_ids, range(self.config.mt_num_speakers)):
454
+ current_segments[prev_i] += segments[i]
455
+ seek[prev_i] += segment_offset[i]
456
+
457
+ if self.config.uses_enrollments:
458
+ if seek[prev_ids[0]] >= max_frames[prev_ids[0]]:
459
+ seek[prev_ids[1]] = max_frames[prev_ids[1]]
460
+
461
+
462
+ # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
463
+ # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
464
+ final_segments = (
465
+ [x[1:] for x in current_segments]
466
+ if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
467
+ else current_segments
468
+ )
469
+ if "is_valid" in kwargs:
470
+ final_segments = [seg for idx, seg in enumerate(final_segments) if kwargs['is_valid'][idx]]
471
+ sequences = _pad_to_max_length(
472
+ final_segments, generation_config.pad_token_id, device=self.device, padding="right"
473
+ )
474
+
475
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
476
+ output = {"sequences": sequences, "segments": final_segments}
477
+
478
+ self.encoder_logits = None
479
+
480
+ if isinstance(output, dict):
481
+ output = self._fix_timestamps_from_segmentation(output)
482
+
483
+ return output
484
+
485
+ @staticmethod
486
+ def _find_common_seek(sequences, seeks):
487
+ """
488
+ Finds the minimum seek that does not overlap with other sequences,
489
+ and falls back to other segment boundaries if needed. Assumes:
490
+ - 'seeks' is a list of per-sequence seek values (ints or 0-dim tensors),
491
+ - seek values are in timestamp * 100 format (e.g., 125.5 s -> 12550).
492
+ """
493
+
494
+ def is_valid_seek(seek_time, exclude_seq_idx):
495
+ for idx, seq in enumerate(sequences):
496
+ if idx == exclude_seq_idx:
497
+ continue
498
+ for segment in seq:
499
+ start = getattr(segment, 'start', segment['start'])
500
+ end = getattr(segment, 'end', segment['end'])
501
+ if seek_time < start:
502
+ break # Segments are sorted by end
503
+ if start < seek_time < end:
504
+ return False
505
+ return True
506
+
507
+ # Step 1: Find minimum seek
508
+ # if all seek values are the same, return it immediately
509
+ seeks = [s if isinstance(s, int) else s.item() for s in seeks]
510
+ if len(set(seeks)) == 1:
511
+ return seeks[0]
512
+
513
+ min_seek_val = min(seeks)
514
+ min_seek_idx = seeks.index(min_seek_val)
515
+ min_seek_real = min_seek_val / 100
516
+
517
+ if is_valid_seek(min_seek_real, min_seek_idx):
518
+ return min_seek_val
519
+
520
+ # Step 2: Try fallback seeks taken from the segment boundaries of all sequences
521
+ fallback_seeks = set()
522
+ for idx, seq in enumerate(sequences):
523
+ for segment in seq:
524
+ start = getattr(segment, 'start', segment['start'])
525
+ if isinstance(start, torch.Tensor):
526
+ start = start.item()
527
+ candidate = round(start, 2)
528
+ fallback_seeks.add((candidate, idx, True))
529
+ end = getattr(segment, 'end', segment['end'])
530
+ if isinstance(end, torch.Tensor):
531
+ end = end.item()
532
+ if end < min_seek_real:
533
+ candidate = round(end, 2)
534
+ fallback_seeks.add((candidate, idx, True))
535
+
536
+ valid_fallbacks = [
537
+ (int(s * 100), idx, is_start) for s, idx, is_start in fallback_seeks
538
+ if is_valid_seek(s, min_seek_idx)
539
+ ]
540
+
541
+ if valid_fallbacks:
542
+ return max(valid_fallbacks)[0]  # return the seek value, consistent with the other return paths
543
+
544
+ # Step 3: Nothing valid
545
+ return 0
546
+
547
+ @staticmethod
548
+ def remove_segments_after_seek(sequences, seek, eps=100):
549
+ """
550
+ Keep only segments that finish before given timestamp.
551
+
552
+ Args:
553
+ sequences: List of lists, each containing segments (dict or object with 'start' and 'end').
554
+ seek: Integer seek timestamp (e.g., timestamp * 100).
555
+
556
+ Returns:
557
+ None. Modifies the sequences in-place.
558
+ """
559
+ return [[seg for seg in seq if (getattr(seg, 'end', seg['end']) * 100 <= seek + eps)] for seq in sequences]
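+ # Note: seek is in timestamp * 100 units, so the default eps=100 lets a segment overshoot
+ # the seek position by up to 1 second and still be kept.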
560
+
561
+ @staticmethod
562
+ def _retrieve_segment_wo_seek(
563
+ seek_sequence,
564
+ seek_outputs,
565
+ time_offset,
566
+ timestamp_begin,
567
+ seek_num_frames,
568
+ time_precision,
569
+ input_stride,
570
+ prev_idx,
571
+ idx,
572
+ return_token_timestamps,
573
+ ):
574
+ # find the predicted "end of segment" predictions of Whisper
575
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
576
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
577
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
578
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
579
+ timestamp_segment_indices.add_(1)
580
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
581
+
582
+ # If Whisper predicted an "end of segment" via a timestamp token, let's go over each
583
+ # "end of segment" prediction and slice the decoding into segments accordingly
584
+ if len(timestamp_segment_indices) > 0:
585
+ # if the output contains two consecutive timestamp tokens
586
+ slices = timestamp_segment_indices.tolist()
587
+ segments = []
588
+ if single_timestamp_ending:
589
+ slices.append(len(seek_sequence))
590
+
591
+ last_slice = 0
592
+ # Add each segment to list of all segments
593
+ for current_slice in slices:
594
+ sliced_tokens = seek_sequence[last_slice:current_slice]
595
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
596
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
597
+ segments.append(
598
+ {
599
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
600
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
601
+ "tokens": sliced_tokens,
602
+ "result": seek_outputs[idx],
603
+ }
604
+ )
605
+ if return_token_timestamps:
606
+ segments[-1]["token_timestamps"] = (
607
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
608
+ )
609
+ last_slice = current_slice
610
+
611
+ if not single_timestamp_ending:
612
+ # generate all predictions after the last predicted "end of segment" and seek by 30s
613
+ sliced_tokens = seek_sequence[last_slice:]
614
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
615
+ end_timestamp_pos = seek_num_frames[prev_idx] // 2
616
+ segments.append(
617
+ {
618
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
619
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
620
+ "tokens": sliced_tokens,
621
+ "result": seek_outputs[idx],
622
+ }
623
+ )
624
+ segment_offset = seek_num_frames[prev_idx]
625
+ else:
626
+ # If whisper does not predict any "end of segment" token, then
627
+ # the whole decoding is considered a segment and we add it to the list of segments
628
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
629
+ start_timestamp_pos = 0.0
630
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
631
+
632
+ if timestamps.numel() > 1:
633
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
634
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
635
+ elif timestamps.numel() == 1:
636
+ # no consecutive timestamps but it has a timestamp; use the last one.
637
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
638
+ segments = [
639
+ {
640
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
641
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
642
+ "tokens": seek_sequence,
643
+ "result": seek_outputs[idx],
644
+ }
645
+ ]
646
+
647
+ segment_offset = seek_num_frames[prev_idx]
648
+
649
+ return segments, segment_offset
650
+
651
+ def _retrieve_segment_mt(
652
+ self,
653
+ seek_sequences,
654
+ seek_outputs,
655
+ time_offset,
656
+ timestamp_begin,
657
+ seek_num_frames,
658
+ time_precision,
659
+ input_stride,
660
+ prev_ids,
661
+ ids,
662
+ return_token_timestamps,
663
+ ):
664
+ sequences, seeks = [], []
665
+ for sequence, prev_id, idx in zip(seek_sequences, prev_ids, ids):
666
+ seq, seek = self._retrieve_segment(
667
+ seek_sequence=sequence,
668
+ seek_outputs=seek_outputs,
669
+ time_offset=time_offset,
670
+ timestamp_begin=timestamp_begin,
671
+ seek_num_frames=seek_num_frames,
672
+ time_precision=time_precision,
673
+ input_stride=input_stride,
674
+ prev_idx=prev_id,
675
+ idx=idx,
676
+ return_token_timestamps=return_token_timestamps,
677
+ )
678
+ sequences.append(seq)
679
+ seeks.append(seek)
680
+ return sequences, seeks
681
+
682
+ def _beam_search(
683
+ self,
684
+ input_ids: torch.LongTensor,
685
+ beam_scorer: BeamScorer,
686
+ logits_processor: LogitsProcessorList,
687
+ stopping_criteria: StoppingCriteriaList,
688
+ generation_config: GenerationConfig,
689
+ synced_gpus: bool,
690
+ logits_warper: Optional[LogitsProcessorList] = None,
691
+ **model_kwargs,
692
+ ) -> Union[GenerateBeamOutput, torch.LongTensor]:
693
+ r"""
694
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
695
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
696
+
697
+ Parameters:
698
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
699
+ The sequence used as a prompt for the generation.
700
+ beam_scorer (`BeamScorer`):
701
+ A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
702
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
703
+ logits_processor (`LogitsProcessorList`):
704
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
705
+ used to modify the prediction scores of the language modeling head applied at each generation step.
706
+ stopping_criteria (`StoppingCriteriaList`):
707
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
708
+ used to tell if the generation loop should stop.
709
+ generation_config ([`~generation.GenerationConfig`]):
710
+ The generation configuration to be used as parametrization of the decoding method.
711
+ synced_gpus (`bool`):
712
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
713
+ logits_warper (`LogitsProcessorList`, *optional*):
714
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
715
+ to warp the prediction score distribution of the language modeling head applied before multinomial
716
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
717
+ `generation_config`)
718
+ model_kwargs:
719
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
720
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
721
+
722
+ Return:
723
+ [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
724
+ `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
725
+ [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
726
+ `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
727
+ `model.config.is_encoder_decoder=True`.
728
+ """
729
+ # init values
730
+ pad_token_id = generation_config.pad_token_id
731
+ eos_token_id = generation_config.eos_token_id
732
+ output_attentions = generation_config.output_attentions
733
+ output_hidden_states = generation_config.output_hidden_states
734
+ output_scores = generation_config.output_scores
735
+ output_logits = generation_config.output_logits
736
+ return_dict_in_generate = generation_config.return_dict_in_generate
737
+ sequential = generation_config.low_memory
738
+ do_sample = generation_config.do_sample
739
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
740
+ raise ValueError(
741
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
742
+ f"{logits_warper})."
743
+ )
744
+
745
+ beam_scorer._beam_hyps = beam_scorer._beam_hyps[:self.encoder_logits.shape[0]]
746
+
747
+ batch_size = len(beam_scorer._beam_hyps)
748
+ num_beams = beam_scorer.num_beams
749
+
750
+ batch_beam_size, cur_len = input_ids.shape
751
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
752
+
753
+ if num_beams * batch_size != batch_beam_size:
754
+ raise ValueError(
755
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
756
+ )
757
+
758
+ # init attention / hidden states / scores tuples
759
+ scores = () if (return_dict_in_generate and output_scores) else None
760
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
761
+ beam_indices = (
762
+ tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
763
+ )
764
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
765
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
766
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
767
+
768
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
769
+ if return_dict_in_generate and self.config.is_encoder_decoder:
770
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
771
+ encoder_hidden_states = (
772
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
773
+ )
774
+
775
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
776
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
777
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
778
+ beam_scores[:, 1:] = -1e9
779
+ beam_scores = beam_scores.view((batch_size * num_beams,))
780
+
781
+ this_peer_finished = False
782
+
783
+ decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
784
+
785
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
786
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
787
+
788
+ # if sequential is True, split the input to batches of batch_size and run sequentially
789
+ if sequential:
790
+ if any(
791
+ model_name in self.__class__.__name__.lower()
792
+ for model_name in [
793
+ "fsmt",
794
+ "reformer",
795
+ "bloom",
796
+ "ctrl",
797
+ "gpt_bigcode",
798
+ "transo_xl",
799
+ "xlnet",
800
+ "cpm",
801
+ "jamba",
802
+ ]
803
+ ):
804
+ raise RuntimeError(
805
+ f"Currently generation for {self.__class__.__name__} is not supported "
806
+ f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
807
+ )
808
+
809
+ inputs_per_sub_batches = _split_model_inputs(
810
+ model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
811
+ )
812
+ outputs_per_sub_batch = [
813
+ self(
814
+ **inputs_per_sub_batch,
815
+ return_dict=True,
816
+ output_attentions=output_attentions,
817
+ output_hidden_states=output_hidden_states,
818
+ )
819
+ for inputs_per_sub_batch in inputs_per_sub_batches
820
+ ]
821
+
822
+ outputs = stack_model_outputs(outputs_per_sub_batch)
823
+
824
+ else: # Unchanged original behavior
825
+ outputs = self(
826
+ **model_inputs,
827
+ return_dict=True,
828
+ output_attentions=output_attentions,
829
+ output_hidden_states=output_hidden_states,
830
+ )
831
+
832
+ if synced_gpus and this_peer_finished:
833
+ cur_len = cur_len + 1
834
+ continue # don't waste resources running the code we don't need
835
+
836
+ next_token_logits = outputs.logits[:, -1, :]
837
+ next_token_scores = nn.functional.log_softmax(
838
+ next_token_logits, dim=-1
839
+ ) # (batch_size * num_beams, vocab_size)
840
+
841
+ next_token_scores_processed = logits_processor(input_ids, next_token_scores)
842
+ if do_sample:
843
+ next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
844
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
845
+ next_token_scores_processed
846
+ )
847
+
848
+ # Store scores, attentions and hidden_states when required
849
+ if return_dict_in_generate:
850
+ if output_scores:
851
+ scores += (next_token_scores_processed,)
852
+ if output_logits:
853
+ raw_logits += (next_token_logits,)
854
+ if output_attentions:
855
+ decoder_attentions += (
856
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
857
+ )
858
+ if self.config.is_encoder_decoder:
859
+ cross_attentions += (outputs.cross_attentions,)
860
+ if output_hidden_states:
861
+ decoder_hidden_states += (
862
+ (outputs.decoder_hidden_states,)
863
+ if self.config.is_encoder_decoder
864
+ else (outputs.hidden_states,)
865
+ )
866
+
867
+ # reshape for beam search
868
+ vocab_size = next_token_scores.shape[-1]
869
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
870
+
871
+ # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
872
+ # non eos token per beam.
873
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
874
+ n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
875
+ if do_sample:
876
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
877
+ next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
878
+ next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
879
+ next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
880
+ next_tokens = torch.gather(next_tokens, -1, _indices)
881
+ else:
882
+ next_token_scores, next_tokens = torch.topk(
883
+ next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
884
+ )
885
+
886
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
887
+ next_tokens = next_tokens % vocab_size
888
+
889
+ # stateless
890
+ beam_outputs = beam_scorer.process(
891
+ input_ids,
892
+ next_token_scores,
893
+ next_tokens,
894
+ next_indices,
895
+ pad_token_id=pad_token_id,
896
+ eos_token_id=eos_token_id,
897
+ beam_indices=beam_indices,
898
+ decoder_prompt_len=decoder_prompt_len,
899
+ )
900
+
901
+ beam_scores = beam_outputs["next_beam_scores"]
902
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
903
+ beam_idx = beam_outputs["next_beam_indices"]
904
+
905
+ # Based on the beam idx and next tokens reshuffle the ctc prev states and scores
906
+ if hasattr(self, "ctc_rescorer"):
907
+ self.ctc_rescorer.update_state(beam_next_tokens, beam_idx)
908
+ input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
909
+
910
+ model_kwargs = self._update_model_kwargs_for_generation(
911
+ outputs,
912
+ model_kwargs,
913
+ is_encoder_decoder=self.config.is_encoder_decoder,
914
+ )
915
+ if model_kwargs.get("past_key_values", None) is not None:
916
+ model_kwargs["past_key_values"] = self._temporary_reorder_cache(
917
+ model_kwargs["past_key_values"], beam_idx
918
+ )
919
+
920
+ if return_dict_in_generate and output_scores:
921
+ beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
922
+
923
+ # increase cur_len
924
+ cur_len = cur_len + 1
925
+
926
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
927
+ this_peer_finished = True
928
+
929
+ sequence_outputs = beam_scorer.finalize(
930
+ input_ids,
931
+ beam_scores,
932
+ next_tokens,
933
+ next_indices,
934
+ pad_token_id=pad_token_id,
935
+ eos_token_id=eos_token_id,
936
+ max_length=stopping_criteria.max_length,
937
+ beam_indices=beam_indices,
938
+ decoder_prompt_len=decoder_prompt_len,
939
+ )
940
+
941
+ if return_dict_in_generate:
942
+ if not output_scores:
943
+ sequence_outputs["sequence_scores"] = None
944
+
945
+ if self.config.is_encoder_decoder:
946
+ return GenerateBeamEncoderDecoderOutput(
947
+ sequences=sequence_outputs["sequences"],
948
+ sequences_scores=sequence_outputs["sequence_scores"],
949
+ scores=scores,
950
+ logits=raw_logits,
951
+ beam_indices=sequence_outputs["beam_indices"],
952
+ encoder_attentions=encoder_attentions,
953
+ encoder_hidden_states=encoder_hidden_states,
954
+ decoder_attentions=decoder_attentions,
955
+ cross_attentions=cross_attentions,
956
+ decoder_hidden_states=decoder_hidden_states,
957
+ past_key_values=model_kwargs.get("past_key_values"),
958
+ )
959
+ else:
960
+ return GenerateBeamDecoderOnlyOutput(
961
+ sequences=sequence_outputs["sequences"],
962
+ sequences_scores=sequence_outputs["sequence_scores"],
963
+ scores=scores,
964
+ logits=raw_logits,
965
+ beam_indices=sequence_outputs["beam_indices"],
966
+ attentions=decoder_attentions,
967
+ hidden_states=decoder_hidden_states,
968
+ past_key_values=model_kwargs.get("past_key_values"),
969
+ )
970
+ else:
971
+ return sequence_outputs["sequences"]
972
+
973
+ def _sample(
974
+ self,
975
+ input_ids: torch.LongTensor,
976
+ logits_processor: LogitsProcessorList,
977
+ stopping_criteria: StoppingCriteriaList,
978
+ generation_config: GenerationConfig,
979
+ synced_gpus: bool,
980
+ streamer: Optional["BaseStreamer"],
981
+ logits_warper: Optional[LogitsProcessorList] = None,
982
+ **model_kwargs,
983
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
984
+ r"""
985
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
986
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
987
+
988
+ Parameters:
989
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
990
+ The sequence used as a prompt for the generation.
991
+ logits_processor (`LogitsProcessorList`):
992
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
993
+ used to modify the prediction scores of the language modeling head applied at each generation step.
994
+ stopping_criteria (`StoppingCriteriaList`):
995
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
996
+ used to tell if the generation loop should stop.
997
+ generation_config ([`~generation.GenerationConfig`]):
998
+ The generation configuration to be used as parametrization of the decoding method.
999
+ synced_gpus (`bool`):
1000
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
1001
+ streamer (`BaseStreamer`, *optional*):
1002
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
1003
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
1004
+ logits_warper (`LogitsProcessorList`, *optional*):
1005
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
1006
+ to warp the prediction score distribution of the language modeling head applied before multinomial
1007
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
1008
+ `generation_config`)
1009
+ model_kwargs:
1010
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
1011
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
1012
+
1013
+ Return:
1014
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
1015
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
1016
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
1017
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
1018
+ `model.config.is_encoder_decoder=True`.
1019
+ """
1020
+ # init values
1021
+ pad_token_id = generation_config.pad_token_id
1022
+ output_attentions = generation_config.output_attentions
1023
+ output_hidden_states = generation_config.output_hidden_states
1024
+ output_scores = generation_config.output_scores
1025
+ output_logits = generation_config.output_logits
1026
+ return_dict_in_generate = generation_config.return_dict_in_generate
1027
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
1028
+ do_sample = generation_config.do_sample
1029
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
1030
+ raise ValueError(
1031
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
1032
+ f"{logits_warper})."
1033
+ )
1034
+
1035
+ # init attention / hidden states / scores tuples
1036
+ scores = () if (return_dict_in_generate and output_scores) else None
1037
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
1038
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
1039
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
1040
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
1041
+
1042
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
1043
+ if return_dict_in_generate and self.config.is_encoder_decoder:
1044
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
1045
+ encoder_hidden_states = (
1046
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
1047
+ )
1048
+
1049
+ # keep track of which sequences are already finished
1050
+ batch_size = input_ids.shape[0]
1051
+ this_peer_finished = False
1052
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
1053
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
1054
+
1055
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
1056
+ # prepare model inputs
1057
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1058
+
1059
+ # forward pass to get next token
1060
+ outputs = self(
1061
+ **model_inputs,
1062
+ return_dict=True,
1063
+ output_attentions=output_attentions,
1064
+ output_hidden_states=output_hidden_states,
1065
+ )
1066
+
1067
+ if synced_gpus and this_peer_finished:
1068
+ continue # don't waste resources running the code we don't need
1069
+
1070
+ next_token_logits = outputs.logits[:, -1, :]
1071
+
1072
+ # pre-process distribution
1073
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1074
+ if do_sample:
1075
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1076
+
1077
+ # Store scores, attentions and hidden_states when required
1078
+ if return_dict_in_generate:
1079
+ if output_scores:
1080
+ scores += (next_token_scores,)
1081
+ if output_logits:
1082
+ raw_logits += (next_token_logits,)
1083
+ if output_attentions:
1084
+ decoder_attentions += (
1085
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
1086
+ )
1087
+ if self.config.is_encoder_decoder:
1088
+ cross_attentions += (outputs.cross_attentions,)
1089
+
1090
+ if output_hidden_states:
1091
+ decoder_hidden_states += (
1092
+ (outputs.decoder_hidden_states,)
1093
+ if self.config.is_encoder_decoder
1094
+ else (outputs.hidden_states,)
1095
+ )
1096
+
1097
+ # token selection
1098
+ if do_sample:
1099
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1100
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1101
+ else:
1102
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
1103
+
1104
+ # finished sentences should have their next token be a padding token
1105
+ if has_eos_stopping_criteria:
1106
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1107
+
1108
+ # Based on the next tokens select the ctc prev states and scores
1109
+ if hasattr(self, "ctc_rescorer"):
1110
+ self.ctc_rescorer.update_state(next_tokens, torch.arange(next_tokens.shape[0]))
1111
+
1112
+ # update generated ids, model inputs, and length for next step
1113
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1114
+ if streamer is not None:
1115
+ streamer.put(next_tokens.cpu())
1116
+ model_kwargs = self._update_model_kwargs_for_generation(
1117
+ outputs,
1118
+ model_kwargs,
1119
+ is_encoder_decoder=self.config.is_encoder_decoder,
1120
+ )
1121
+
1122
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
1123
+ this_peer_finished = unfinished_sequences.max() == 0
1124
+
1125
+ if streamer is not None:
1126
+ streamer.end()
1127
+
1128
+ if return_dict_in_generate:
1129
+ if self.config.is_encoder_decoder:
1130
+ return GenerateEncoderDecoderOutput(
1131
+ sequences=input_ids,
1132
+ scores=scores,
1133
+ logits=raw_logits,
1134
+ encoder_attentions=encoder_attentions,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ decoder_attentions=decoder_attentions,
1137
+ cross_attentions=cross_attentions,
1138
+ decoder_hidden_states=decoder_hidden_states,
1139
+ past_key_values=model_kwargs.get("past_key_values"),
1140
+ )
1141
+ else:
1142
+ return GenerateDecoderOnlyOutput(
1143
+ sequences=input_ids,
1144
+ scores=scores,
1145
+ logits=raw_logits,
1146
+ attentions=decoder_attentions,
1147
+ hidden_states=decoder_hidden_states,
1148
+ past_key_values=model_kwargs.get("past_key_values"),
1149
+ )
1150
+ else:
1151
+ return input_ids
1152
+
1153
+ def prepare_kwargs_for_generate(self,
1154
+ segment_input,
1155
+ cur_bsz,
1156
+ batch_idx_map,
1157
+ seek,
1158
+ num_segment_frames,
1159
+ max_frames,
1160
+ kwargs):
1161
+ kwargs["attention_mask_enc"] = torch.ones(cur_bsz, segment_input.size(-1), device=segment_input.device)
1162
+ seek_vad = seek // 2
1163
+ num_frames_vad = num_segment_frames // 2
1164
+ max_frames_vad = max_frames // 2
1165
+ seek_num_frames = (max_frames_vad - seek_vad).clamp(max=num_frames_vad)
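+ # Note: the STNO/VAD mask runs at half the mel-frame rate (the encoder downsamples time by 2),
+ # so seek positions and frame counts are halved here; e.g. a 30 s window of 3000 mel frames
+ # corresponds to 1500 mask frames.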
1166
+
1167
+ stno_masks = []
1168
+ for i in range(cur_bsz):
1169
+ prev_i = batch_idx_map[i]
1170
+ segment_input_slice = kwargs["stno_mask"][prev_i: prev_i + 1, :,
1171
+ seek_vad[prev_i]: seek_vad[prev_i] + seek_num_frames[prev_i]]
1172
+
1173
+ if segment_input_slice.shape[-1] < num_frames_vad:
1174
+ orig_len = segment_input_slice.shape[-1]
1175
+ # pad the mask slice to the full window length (num_frames_vad, i.e. half of 3000) if necessary
1176
+ segment_input_slice = torch.nn.functional.pad(
1177
+ segment_input_slice, pad=(0, num_frames_vad - orig_len)
1178
+ )
1179
+ # mark the padded positions as silence (set the first STNO channel to 1)
1180
+ segment_input_slice[0, 0, orig_len:] = 1.0
1181
+
1182
+ stno_masks.append(segment_input_slice)
1183
+ kwargs["stno_mask"] = torch.cat(stno_masks, dim=0)
1184
+ self.stno_mask_seek = kwargs["stno_mask"]
1185
+
1186
+ if "per_group_sizes" in kwargs:
1187
+ group_sizes = kwargs["per_group_sizes"].clone()
1188
+ group_sizes[:] = 0
1189
+ cumulative_group_sizes = (
1190
+ kwargs["per_group_sizes"].max().repeat(kwargs["per_group_sizes"].shape[0])).cumsum(dim=0)
1191
+ for i in batch_idx_map:
1192
+ group_idx = (cumulative_group_sizes > i).nonzero().min()
1193
+ group_sizes[group_idx] += 1
1194
+ kwargs["per_group_sizes"] = group_sizes
1195
+
1196
+ if self.vad_seek_callback is not None:
1197
+ self.vad_seek_callback(kwargs["stno_mask"])
1198
+ if "is_valid" in kwargs:
1199
+ kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
1200
+ kwargs['labels'] = kwargs["labels"][batch_idx_map]
1201
+ kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
1202
+ return kwargs
1203
+
1204
+ def generate_with_fallback(
1205
+ self,
1206
+ segment_input,
1207
+ decoder_input_ids,
1208
+ cur_bsz,
1209
+ batch_idx_map,
1210
+ seek,
1211
+ num_segment_frames,
1212
+ max_frames,
1213
+ temperatures,
1214
+ generation_config,
1215
+ logits_processor,
1216
+ stopping_criteria,
1217
+ prefix_allowed_tokens_fn,
1218
+ synced_gpus,
1219
+ return_token_timestamps,
1220
+ do_condition_on_prev_tokens,
1221
+ kwargs,
1222
+ ):
1223
+ kwargs = copy.copy(kwargs)
1224
+ kwargs = self.prepare_kwargs_for_generate(segment_input, cur_bsz, batch_idx_map, seek, num_segment_frames,
1225
+ max_frames, kwargs)
1226
+ seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = super().generate_with_fallback(
1227
+ segment_input,
1228
+ decoder_input_ids,
1229
+ cur_bsz,
1230
+ batch_idx_map,
1231
+ seek,
1232
+ num_segment_frames,
1233
+ max_frames,
1234
+ temperatures,
1235
+ generation_config,
1236
+ logits_processor,
1237
+ stopping_criteria,
1238
+ prefix_allowed_tokens_fn,
1239
+ synced_gpus,
1240
+ return_token_timestamps,
1241
+ do_condition_on_prev_tokens,
1242
+ kwargs,
1243
+ )
1244
+ self.stno_mask_seek = None
1245
+
1246
+ if "is_valid" in kwargs:
1247
+ seek_sequences_tmp = [torch.tensor([])] * len(seek_sequences)
1248
+ seek_outputs_tmp = [torch.tensor([])] * len(seek_sequences)
1249
+ should_skip_tmp = [False] * len(seek_sequences)
1250
+ do_condition_on_prev_tokens_tmp = [None] * len(seek_sequences)
1251
+
1252
+ non_valid_inc = 0
1253
+ for idx, is_valid in enumerate(kwargs["is_valid"]):
1254
+ if is_valid:
1255
+ seek_sequences_tmp[idx] = seek_sequences[non_valid_inc]
1256
+ seek_outputs_tmp[idx] = seek_outputs[non_valid_inc]
1257
+ should_skip_tmp[idx] = should_skip[non_valid_inc]
1258
+ do_condition_on_prev_tokens_tmp[idx] = do_condition_on_prev_tokens[non_valid_inc]
1259
+ non_valid_inc += 1
1260
+ seek_sequences = seek_sequences_tmp
1261
+ seek_outputs = seek_outputs_tmp
1262
+ should_skip = should_skip_tmp
1263
+ do_condition_on_prev_tokens = do_condition_on_prev_tokens_tmp
1264
+
1265
+
1266
+ # for i, seq in enumerate(seek_outputs):
1267
+ # print(f"Sequence {i} {self.safe_tokenizer_decode(kwargs['labels'][batch_idx_map[i]])}: {self.tokenizer.decode(seq, decode_with_timestamps=True)}")
1268
+ # print("-"*50)
1269
+
1270
+ return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens
1271
+
1272
+ def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
1273
+ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1274
+ """short function to replace num with a itr in lst"""
1275
+ found = any(i in lst for i in itr)
1276
+ if found:
1277
+ lst = [num if i in itr else i for i in lst]
1278
+ else:
1279
+ lst.append(num)
1280
+ return lst
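+ # Illustrative behaviour: replace_or_add([50258, 50259], 50360, {50259, 50260}) -> [50258, 50360]
+ # (replacement), while replace_or_add([50258], 50360, {50259, 50260}) -> [50258, 50360] (append).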
1281
+
1282
+ def language_to_id(language: str) -> int:
1283
+ language = language.lower()
1284
+ if language in generation_config.lang_to_id.keys():
1285
+ language_token = language
1286
+ elif language in TO_LANGUAGE_CODE.keys():
1287
+ language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
1288
+ elif language in TO_LANGUAGE_CODE.values():
1289
+ language_token = f"<|{language}|>"
1290
+ else:
1291
+ is_language_code = len(language) == 2
1292
+ raise ValueError(
1293
+ f"Unsupported language: {language}. Language should be one of:"
1294
+ f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
1295
+ )
1296
+ if language_token not in generation_config.lang_to_id:
1297
+ raise ValueError(
1298
+ f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
1299
+ "(You should just add it to the generation config)"
1300
+ )
1301
+
1302
+ return generation_config.lang_to_id[language_token]
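+ # e.g. both language_to_id("english") and language_to_id("en") resolve to the id of the
+ # "<|en|>" token, provided "<|en|>" is present in generation_config.lang_to_id.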
1303
+
1304
+ task = getattr(generation_config, "task", None)
1305
+ language = getattr(generation_config, "language", None)
1306
+
1307
+ forced_decoder_ids = generation_config.forced_decoder_ids
1308
+ if forced_decoder_ids is not None:
1309
+ if language is None and task is None and forced_decoder_ids[0][1] is None:
1310
+ logger.warning_once(
1311
+ "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1312
+ "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1313
+ )
1314
+ elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
1315
+ forced_decoder_ids = config.forced_decoder_ids
1316
+
1317
+ elif forced_decoder_ids is not None and language is not None:
1318
+ logger.info(
1319
+ f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
1320
+ )
1321
+ forced_decoder_ids = None
1322
+
1323
+ init_tokens = [generation_config.decoder_start_token_id]
1324
+
1325
+ # Update init_tokens with languages
1326
+ lang_ids = None
1327
+
1328
+ if forced_decoder_ids is not None:
1329
+ return forced_decoder_ids
1330
+
1331
+ # from v4.39 the forced decoder ids are always None in favour of decoder input ids
1332
+ generation_config.forced_decoder_ids = None
1333
+
1334
+ is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
1335
+
1336
+ # Make sure language is a list of strings of the correct length
1337
+ if isinstance(language, (list, tuple)):
1338
+ if any(l is None for l in language):
1339
+ raise TypeError(
1340
+ "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
1341
+ )
1342
+ if len(language) != batch_size:
1343
+ raise ValueError(
1344
+ "When passing a list of languages, the length of the list must match the batch size. "
1345
+ f"Expected length of {batch_size}, but got {len(language)} languages."
1346
+ )
1347
+ languages = language
1348
+ elif language is None:
1349
+ # Language will be detected for each item in batch
1350
+ languages = [None] * batch_size
1351
+ else:
1352
+ languages = [language] # Use a length-1 list now, broadcast later
1353
+
1354
+ # Separate init_tokens for each language
1355
+ init_tokens = [copy.copy(init_tokens) for _ in languages]
1356
+
1357
+ if language is not None and lang_ids is not None:
1358
+ lang_ids = [language_to_id(l) for l in languages]
1359
+ elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
1360
+ # language is not defined or intentionally set to `None` to trigger language detection
1361
+ lang_ids = self.detect_language(
1362
+ input_features=input_features,
1363
+ encoder_outputs=kwargs.get("encoder_outputs", None),
1364
+ generation_config=generation_config,
1365
+ num_segment_frames=num_segment_frames,
1366
+ ).tolist()
1367
+ if lang_ids is not None:
1368
+ # append or replace lang_ids to init_tokens
1369
+ for i in range(len(init_tokens)):
1370
+ if len(init_tokens[i]) > 1:
1371
+ init_tokens[i][1] = lang_ids[i]
1372
+ else:
1373
+ init_tokens[i].append(lang_ids[i])
1374
+ del languages
1375
+
1376
+ # Update init_tokens with task
1377
+ for i in range(len(init_tokens)):
1378
+ if task is not None:
1379
+ if task in TASK_IDS:
1380
+ init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1381
+ task_id = generation_config.task_to_id[generation_config.task]
1382
+
1383
+ # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
1384
+ replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
1385
+ else:
1386
+ raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
1387
+ elif language is not None and hasattr(generation_config, "task_to_id"):
1388
+ # if language is defined, but no task id is in `init_tokens`, default to transcribe
1389
+ if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
1390
+ init_tokens[i].append(generation_config.task_to_id["transcribe"])
1391
+
1392
+ # let's make sure we don't pass `None` tokens as prompt tokens
1393
+ init_tokens[i] = [t for t in init_tokens[i] if t is not None]
1394
+
1395
+ return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
1396
+
1397
+ def detect_language(
1398
+ self,
1399
+ input_features: Optional[torch.FloatTensor] = None,
1400
+ encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
1401
+ generation_config: Optional[GenerationConfig] = None,
1402
+ num_segment_frames: int = 3000,
1403
+ ) -> torch.Tensor:
1404
+ """
1405
+ Detects language from log-mel input features or encoder_outputs
1406
+
1407
+ Parameters:
1408
+ input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
1409
+ Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
1410
+ loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
1411
+ the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
1412
+ [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
1413
+ tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
1414
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1415
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1416
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
1417
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1418
+ generation_config (`~generation.GenerationConfig`, *optional*):
1419
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1420
+ passed to generate matching the attributes of `generation_config` will override them. If
1421
+ `generation_config` is not provided, the default will be used, which had the following loading
1422
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1423
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1424
+ default values, whose documentation should be checked to parameterize generation.
1425
+ num_segment_frames (`int`, defaults to 3000):
1426
+ The number of log-mel frames the model expects
1427
+
1428
+ Return:
1429
+ A `torch.LongTensor` representing the detected language ids.
1430
+ """
1431
+ if input_features is None and encoder_outputs is None:
1432
+ raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
1433
+ elif input_features is not None and encoder_outputs is not None:
1434
+ raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
1435
+ elif input_features is not None:
1436
+ inputs = {"input_features": input_features[:, :, :num_segment_frames]}
1437
+ batch_size = input_features.shape[0]
1438
+ elif encoder_outputs is not None:
1439
+ inputs = {"encoder_outputs": encoder_outputs}
1440
+ batch_size = (
1441
+ encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
1442
+ )
1443
+
1444
+ generation_config = generation_config or self.generation_config
1445
+ decoder_input_ids = (
1446
+ torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
1447
+ * generation_config.decoder_start_token_id
1448
+ )
1449
+
1450
+ with torch.no_grad():
1451
+ logits = self(**inputs, decoder_input_ids=decoder_input_ids,
1452
+ stno_mask=self.stno_mask[:, :, :num_segment_frames // 2]).logits[:, -1]
1453
+
1454
+ non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
1455
+ non_lang_mask[list(generation_config.lang_to_id.values())] = False
1456
+
1457
+ logits[:, non_lang_mask] = -np.inf
1458
+
1459
+ lang_ids = logits.argmax(-1)
1460
+
1461
+ return lang_ids
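+ # Usage sketch (illustrative): lang_ids = self.detect_language(input_features=feats) yields one
+ # language-token id per batch item (e.g. the id of "<|en|>"); all non-language logits are
+ # masked to -inf before taking the argmax.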
1462
+
1463
+ def _get_logits_processor(
1464
+ self,
1465
+ generation_config: GenerationConfig,
1466
+ input_ids_seq_length: int,
1467
+ encoder_input_ids: torch.LongTensor,
1468
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
1469
+ logits_processor: Optional[LogitsProcessorList],
1470
+ device: str = None,
1471
+ model_kwargs: Optional[Dict[str, Any]] = None,
1472
+ negative_prompt_ids: Optional[torch.Tensor] = None,
1473
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
1474
+ ) -> LogitsProcessorList:
1475
+ # pylint: disable=no-member
1476
+ gen_config_copy = copy.deepcopy(generation_config)
1477
+ gen_config_copy.forced_decoder_ids = None
1478
+ processors = super()._get_logits_processor(
1479
+ gen_config_copy,
1480
+ input_ids_seq_length,
1481
+ encoder_input_ids,
1482
+ prefix_allowed_tokens_fn,
1483
+ logits_processor,
1484
+ device,
1485
+ model_kwargs,
1486
+ negative_prompt_ids,
1487
+ negative_prompt_attention_mask,
1488
+ )
1489
+ if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
1490
+ enc_logits = self.encoder_logits
1491
+ if generation_config.num_beams <= 1:
1492
+ processors.append(LogSoftmaxProcessor())
1493
+ else:
1494
+ enc_logits = enc_logits.repeat_interleave(generation_config.num_beams, dim=0)
1495
+ self.ctc_rescorer = CTCRescorerLogitsProcessor(
1496
+ enc_logits,
1497
+ torch.full((enc_logits.shape[0],), fill_value=enc_logits.shape[1],
1498
+ device=enc_logits.device),
1499
+ enc_logits.shape[-1] - 1,
1500
+ generation_config.pad_token_id.item(),
1501
+ generation_config.eos_token_id.item(),
1502
+ generation_config.decoder_start_token_id.item(),
1503
+ self.tokenizer,
1504
+ generation_config.ctc_margin,
1505
+ generation_config.ctc_weight,
1506
+ generation_config.num_beams,
1507
+ False,
1508
+ )
1509
+ processors.append(self.ctc_rescorer)
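+ # Note: with num_beams > 1 the encoder CTC logits were repeated per beam above, so the
+ # rescorer can keep one CTC prefix state per beam hypothesis.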
1510
+ return processors
1511
+
1512
+ def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, is_shortform, num_beams,
1513
+ device):
1514
+ if generation_config.return_timestamps is True:
1515
+ timestamp_processor = WhisperTimeStampLogitsProcessorCustom(generation_config, begin_index=begin_index)
1516
+ logits_processor = (
1517
+ [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
1518
+ )
1519
+
1520
+ if generation_config.suppress_tokens is not None:
1521
+ suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
1522
+ logits_processor = (
1523
+ [suppress_tokens_processor]
1524
+ if logits_processor is None
1525
+ else [suppress_tokens_processor] + logits_processor
1526
+ )
1527
+ generation_config.suppress_tokens = None
1528
+
1529
+ if generation_config.begin_suppress_tokens is not None:
1530
+ begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
1531
+ generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
1532
+ )
1533
+ logits_processor = (
1534
+ [begin_suppress_processor]
1535
+ if logits_processor is None
1536
+ else [begin_suppress_processor] + logits_processor
1537
+ )
1538
+ generation_config.begin_suppress_tokens = None
1539
+
1540
+ if generation_config.no_speech_threshold is not None and not is_shortform:
1541
+ no_speech_detector = WhisperNoSpeechDetection(
1542
+ no_speech_token=generation_config.no_timestamps_token_id - 1,
1543
+ begin_index=begin_index,
1544
+ scores_is_logprobs=num_beams > 1,
1545
+ )
1546
+ logits_processor = (
1547
+ [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
1548
+ )
1549
+ no_speech_detector.set_model(self)
1550
+
1551
+ return logits_processor
1552
+
1553
+ @staticmethod
1554
+ def round_to_nearest_0_02(x):
1555
+ d = Decimal(str(x)) # Use str(x) to preserve input precision
1556
+ step = Decimal('0.02')
1557
+ # Divide, round, multiply back
1558
+ rounded = (d / step).to_integral_value(rounding=ROUND_HALF_UP) * step
1559
+ return rounded
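+ # Illustrative examples: round_to_nearest_0_02(1.234) -> Decimal('1.24'),
+ # round_to_nearest_0_02(0.01) -> Decimal('0.02') (ties round up via ROUND_HALF_UP).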
1560
+
1561
+ def _fix_timestamps_from_segmentation(self, sequences):
1562
+ """
1563
+ Adjusts token sequences with global timestamps to fit within Whisper's 0–30s timestamp token range.
1564
+
1565
+ This function modifies the input sequences by inserting appropriate timestamp tokens and
1566
+ offset corrections to ensure the decoded token order is correct, without splitting any segment.
1567
+ It aligns all timestamps to 0.02-second precision, inserts placeholder segments to bridge
1568
+ time gaps between 30-second windows, and maintains segment continuity during encoding.
1569
+
1570
+ Args:
1571
+ sequences (dict): A dictionary containing:
1572
+ - 'segments': A list of segment lists, each segment being a dict with 'start', 'end', and 'tokens'.
1573
+ - 'sequences': A tensor used to determine device for padding.
1574
+
1575
+ Returns:
1576
+ torch.Tensor: A batch of padded token sequences with corrected timestamp alignment.
1577
+ """
1578
+ # Get the token ID for the "<|0.00|>" timestamp used to detect dummy segments
1579
+ first_timestamp_token = self.tokenizer.get_vocab()["<|0.00|>"]
1580
+ results = []
1581
+
1582
+ # Filter out segments that are either empty or consist only of the "<|0.00|>" token
1583
+ for idx, sequence_segs in enumerate(sequences['segments']):
1584
+ sequences['segments'][idx] = [
1585
+ seg for seg in sequence_segs
1586
+ if len(seg['tokens']) > 0 and (len(seg['tokens']) != 1 or seg['tokens'][0] != first_timestamp_token)
1587
+ ]
1588
+
1589
+ # Iterate over each group of segments (e.g., one per utterance)
1590
+ for idx, sequence_segs in enumerate(sequences['segments']):
1591
+ result = []
1592
+ prev_segment_end_time = None
1593
+ correction = Decimal(0.0)
1594
+
1595
+ for i, seg in enumerate(sequence_segs):
1596
+ # Round start and end times to nearest 0.02 seconds
1597
+ start_time = self.round_to_nearest_0_02(seg['start'].item())
1598
+ end_time = self.round_to_nearest_0_02(seg['end'].item())
1599
+ tokens = seg['tokens']
1600
+
1601
+ # Determine which 30s window this segment falls into
1602
+ current_block = (start_time + correction) // 30
1603
+
1604
+ if prev_segment_end_time is not None:
1605
+ # If not the first segment, calculate difference in 30s windows
1606
+ prev_block = prev_segment_end_time // 30
1607
+ num_dummies = current_block - prev_block - 1
1608
+
1609
+ # Insert (30, [], 30) marker if we're moving to a new block
1610
+ if current_block > prev_block:
1611
+ result.append((30, [], 30))
1612
+
1613
+ # Insert dummy segments to bridge skipped 30s blocks
1614
+ for _ in range(int(num_dummies)):
1615
+ result.append((0, [], 30))
1616
+ else:
1617
+ # For the first segment, add dummy blocks if it starts after 30s
1618
+ for _ in range(int(start_time // 30)):
1619
+ result.append((0, [], 30))
1620
+
1621
+ # Determine whether segment fits in one block or wraps to the next
1622
+ if (start_time + correction) // 30 == (end_time + correction) // 30:
1623
+ # Segment fits within a single 30s window
1624
+ result.append(((start_time + correction) % 30, tokens, (end_time + correction) % 30))
1625
+ else:
1626
+ # Segment would wrap across a 30s boundary
1627
+ new_seg_start = (correction + start_time) % 30
1628
+ new_seg_end = end_time - start_time
1629
+
1630
+ if new_seg_end >= new_seg_start:
1631
+ # Seek back to the beginning of the segment window
1632
+ result.append((new_seg_start, [], new_seg_start))
1633
+ result.append((0, tokens, new_seg_end))
1634
+ # Apply correction to align future timestamps to new 30s block
1635
+ correction = self.round_to_nearest_0_02(-(start_time % 30))
1636
+ else:
1637
+ # Otherwise, just insert with adjusted times
1638
+ result.append((new_seg_start, tokens, new_seg_end))
1639
+ correction = self.round_to_nearest_0_02(30 - (start_time % 30))
1640
+ # print(f'Processed segment {i}, result: {self.tokenizer.decode(self.tokenizer("".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result]))["input_ids"], decode_with_timestamps=True)[-250:]}')
1641
+ # Update the previous segment's end time for next iteration
1642
+ prev_segment_end_time = end_time + correction
1643
+
1644
+ # Convert result segments into a token sequence with proper timestamp formatting
1645
+ encoded = self.tokenizer(
1646
+ "".join([f"<|{seg[0]:.2f}|>{self.tokenizer.decode(seg[1])}<|{seg[2]:.2f}|>" for seg in result])
1647
+ )['input_ids']
1648
+ results.append(encoded)
1649
+
1650
+ # Pad all sequences to the same length for batching
1651
+ sequences = pad_sequence(
1652
+ [torch.tensor(res, device=sequences['sequences'].device) for res in results],
1653
+ batch_first=True,
1654
+ padding_value=self.tokenizer.pad_token_id
1655
+ )
1656
+ return sequences
1657
+
1658
+ @staticmethod
1659
+ def _retrieve_segment(
1660
+ seek_sequence,
1661
+ seek_outputs,
1662
+ time_offset,
1663
+ timestamp_begin,
1664
+ seek_num_frames,
1665
+ time_precision,
1666
+ input_stride,
1667
+ prev_idx,
1668
+ idx,
1669
+ return_token_timestamps,
1670
+ ):
1671
+ # find the predicted "end of segment" predictions of Whisper
1672
+ # "end of segment" predictions occur whenever Whisper predicts a timestamp token
1673
+ timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
1674
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
1675
+ timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
1676
+ timestamp_segment_indices.add_(1)
1677
+ token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
1678
+
1679
+ # If whisper predicted a "end of segment" via a timestep token, let's go ever each
1680
+ # "end of segment" prediction and slice the decoding into segments accordingly
1681
+ if len(timestamp_segment_indices) > 0:
1682
+ # if the output contains two consecutive timestamp tokens
1683
+ slices = timestamp_segment_indices.tolist()
1684
+ segments = []
1685
+ if single_timestamp_ending:
1686
+ slices.append(len(seek_sequence))
1687
+
1688
+ last_slice = 0
1689
+ # Add each segment to list of all segments
1690
+ for current_slice in slices:
1691
+ sliced_tokens = seek_sequence[last_slice:current_slice]
1692
+ start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
1693
+ end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
1694
+ segments.append(
1695
+ {
1696
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1697
+ "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
1698
+ "tokens": sliced_tokens,
1699
+ "result": seek_outputs[idx],
1700
+ }
1701
+ )
1702
+ if return_token_timestamps:
1703
+ segments[-1]["token_timestamps"] = (
1704
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
1705
+ )
1706
+ last_slice = current_slice
1707
+
1708
+ if single_timestamp_ending:
1709
+ # single timestamp at the end means no speech after the last timestamp.
1710
+ segment_offset = seek_num_frames[prev_idx]
1711
+ else:
1712
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
1713
+ # here we throw away all predictions after the last predicted "end of segment"
1714
+ # since we are cutting right in the middle of an audio
1715
+ last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
1716
+ segment_offset = last_timestamp_pos * input_stride
1717
+ else:
1718
+ # If whisper does not predict any "end of segment" token, then
1719
+ # the whole decoding is considered a segment and we add it to the list of segments
1720
+ timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
1721
+ start_timestamp_pos = 0.0
1722
+ last_timestamp_pos = seek_num_frames[prev_idx] // 2
1723
+ skip = False
1724
+ segment_offset = seek_num_frames[prev_idx]
1725
+
1726
+ if timestamps.numel() > 1:
1727
+ start_timestamp_pos = timestamps[-2].item() - timestamp_begin
1728
+ last_timestamp_pos = timestamps[-1].item() - timestamp_begin
1729
+ elif timestamps.numel() == 1:
1730
+ # no consecutive timestamps but it has a timestamp; use the last one.
1731
+ start_timestamp_pos = timestamps[-1].item() - timestamp_begin
1732
+ if start_timestamp_pos > 200:
1733
+ # segment does not fit into decoding window, so we need to rollback
1734
+ segment_offset = start_timestamp_pos * input_stride - 100 # timestamp might be inaccurate
1735
+ skip = True
1736
+ else:
1737
+ # empty sequence, or sequence w/o timestamps
1738
+ skip = True
1739
+
1740
+ if skip:
1741
+ segments = []
1742
+ else:
1743
+ segments = [
1744
+ {
1745
+ "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
1746
+ "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
1747
+ "tokens": seek_sequence,
1748
+ "result": seek_outputs[idx],
1749
+ }
1750
+ ]
1751
+ if return_token_timestamps:
1752
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
1753
+ segment_offset = seek_num_frames[prev_idx]
1754
+
1755
+ if segment_offset <= 0:
1756
+ msg = f"Timestamps: {timestamps}, Segments: {segments}"
1757
+ raise ValueError(f"Segment offset: {segment_offset} <= 0. This should not happen!\n{msg}")
1758
+
1759
+ return segments, segment_offset
1760
+
1761
+ def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config):
1762
+ # remove all previously passed decoder input ids
1763
+ if isinstance(seek_outputs, torch.Tensor):
1764
+ seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1]:]
1765
+ seek_outputs = torch.hstack((
1766
+ seek_outputs,
1767
+ torch.full((seek_outputs.shape[0], 1),
1768
+ fill_value=generation_config.pad_token_id,
1769
+ dtype=seek_outputs.dtype,
1770
+ device=seek_outputs.device
1771
+ )
1772
+ ))
1773
+ # first_eos = (seek_outputs == generation_config.eos_token_id).int().argmax(dim=1)
1774
+ # biggest_timestamp = generation_config.no_timestamps_token_id + 1 + 30 * 50
1775
+
1776
+ # empty_transcriptions = first_eos == 0
1777
+ # seek_outputs[empty_transcriptions, 0] = generation_config.no_timestamps_token_id + 1 # 0.00 timestamp
1778
+ # seek_outputs[empty_transcriptions, 1] = biggest_timestamp # 30.00 timestamp
1779
+ # seek_outputs[empty_transcriptions, 2] = generation_config.eos_token_id # 30.00 timestamp
1780
+
1781
+ return seek_outputs, seek_outputs
1782
+
1783
+ if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
1784
+ num_frames = getattr(generation_config, "num_frames", None)
1785
+ seek_outputs["token_timestamps"] = self._extract_token_timestamps(
1786
+ seek_outputs, generation_config.alignment_heads, num_frames=num_frames
1787
+ )
1788
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1]:]
1789
+
1790
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1]:]
1791
+
1792
+ def split_by_batch_index(values, key, batch_idx):
1793
+ if key == "scores":
1794
+ return [v[batch_idx].cpu() for v in values]
1795
+ elif key == "past_key_values":
1796
+ # we don't save `past_key_values` as this is too costly
1797
+ return None
1798
+ elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
1799
+ return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
1800
+ return values[batch_idx].cpu()
1801
+
1802
+ sequence_tokens = seek_outputs["sequences"]
1803
+ seek_outputs = [
1804
+ {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
1805
+ for i in range(sequence_tokens.shape[0])
1806
+ ]
1807
+
1808
+ return sequence_tokens, seek_outputs
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "begin_suppress_tokens": [
4
+ 220,
5
+ 50256
6
+ ],
7
+ "bos_token_id": 50257,
8
+ "decoder_start_token_id": 50258,
9
+ "eos_token_id": 50257,
10
+ "pad_token_id": 50257,
11
+ "transformers_version": "4.42.0"
12
+ }
layers.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class CustomLinear(nn.Linear):
7
+ def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
8
+ super().__init__(*args, **kwargs)
9
+ self.init_eye_val = init_eye_val
10
+
11
+ class CustomLinearInitialized(nn.Linear):
12
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
13
+ device=None, dtype=None, init_fun=None) -> None:
14
+ super().__init__(in_features, out_features, bias, device, dtype)
15
+ self.init_fun = init_fun
16
+
17
+ class CustomDiagonalLinear(nn.Module):
18
+ def __init__(self, d_model, bias=True, init_eye_val=0.0):
19
+ super().__init__()
20
+ self.init_eye_val = init_eye_val
21
+ self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
22
+ self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
23
+
24
+ def forward(self, input):
25
+ out = input * self.weight
26
+ if self.bias is not None:
27
+ out += self.bias
28
+ return out
29
+
30
+ class Gate(nn.Module):
31
+ def __init__(self, items, init_val=0.0):
32
+ super().__init__()
33
+ self.init_val = init_val
34
+ self.gate = nn.Parameter(torch.full((items,), init_val))
35
+
36
+ def forward(self, input, dim):
37
+ if input.ndim != 4:
38
+ raise ValueError('input must be a 4D tensor')
39
+ shape = [1] * 4
40
+ shape[dim] = -1
41
+ return input * self.gate.view(*shape)
42
+
43
+
44
+ class AttentivePoolingClassifier(nn.Module):
45
+ def __init__(self, d_model, num_classes, hidden_dim=128):
46
+ """
47
+ Attentive Pooling Classifier
48
+
49
+ Args:
50
+ d_model: Input feature dimension (D)
51
+ num_classes: Number of output classes (V)
52
+ hidden_dim: Hidden dimension for attention mechanism
53
+ """
54
+ super(AttentivePoolingClassifier, self).__init__()
55
+
56
+ # Attention mechanism for pooling [B,T,D] -> [B,D]
57
+ self.attention_projection = nn.Linear(d_model, hidden_dim)
58
+ self.attention_weights = nn.Linear(hidden_dim, 1)
59
+
60
+ # Classifier [B,D] -> [B,V]
61
+ self.classifier = nn.Sequential(
62
+ nn.Linear(d_model, hidden_dim),
63
+ nn.ReLU(),
64
+ nn.Dropout(0.1),
65
+ nn.Linear(hidden_dim, num_classes)
66
+ )
67
+
68
+ def forward(self, x, apply_stop_gradient=True):
69
+ """
70
+ Forward pass
71
+
72
+ Args:
73
+ x: Input tensor of shape [B, T, D]
74
+ apply_stop_gradient: Whether to apply stop gradient
75
+
76
+ Returns:
77
+ logits: Output logits [B, V]
78
+ attention_weights: Attention weights [B, T]
79
+ pooled_features: Pooled features [B, D]
80
+ """
81
+ # Apply stop gradient if specified
82
+ if apply_stop_gradient:
83
+ x = x.detach()
84
+
85
+ # Compute attention weights
86
+ # x: [B, T, D] -> [B, T, hidden_dim]
87
+ att_proj = torch.tanh(self.attention_projection(x))
88
+
89
+ # att_proj: [B, T, hidden_dim] -> [B, T, 1] -> [B, T]
90
+ attention_scores = self.attention_weights(att_proj).squeeze(-1)
91
+ attention_weights = F.softmax(attention_scores, dim=-1)
92
+
93
+ # Apply attentive pooling: [B, T, D] * [B, T, 1] -> [B, D]
94
+ pooled_features = torch.sum(x * attention_weights.unsqueeze(-1), dim=1)
95
+
96
+ # Classification
97
+ logits = self.classifier(pooled_features)
98
+
99
+ return logits
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67601a48c8342a5e8aa5e4542892906703d203fd8ce8fb5009860b72dffe4adc
3
+ size 4672829976
modeling_dicow.py ADDED
@@ -0,0 +1,450 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import CrossEntropyLoss
6
+ import torch.utils.checkpoint
8
+ from transformers.modeling_outputs import Seq2SeqLMOutput
9
+ from transformers.models.speech_encoder_decoder.modeling_speech_encoder_decoder import (
10
+ shift_tokens_right,
11
+ )
12
+ from transformers.models.whisper.modeling_whisper import (
13
+ WhisperEncoder,
14
+ )
15
+ from transformers.models.whisper.modeling_whisper import (
16
+ WhisperForConditionalGeneration,
17
+ shift_tokens_right,
18
+ WhisperModel,
19
+ )
20
+ from transformers.models.whisper.modeling_whisper import sinusoids
21
+ from transformers.utils import logging
22
+
23
+ from .config import Seq2SeqLMOutputLosses, Seq2SeqModelOutputLogit, DiCoWConfig
24
+ from .encoder import DiCoWEncoder
25
+ from .FDDT import FDDT
26
+ from .layers import CustomLinear, CustomDiagonalLinear, Gate, AttentivePoolingClassifier, CustomLinearInitialized
27
+ from .generation import DiCoWGenerationMixin
28
+ from .contrastive_loss import ContrastiveLoss
29
+ import wandb
30
+
31
+ logging.set_verbosity_debug()
32
+ logger = logging.get_logger("transformers")
33
+
34
+
35
+ class DiCoW(WhisperModel):
36
+ def __init__(self, config: DiCoWConfig):
37
+ super().__init__(config)
38
+ self.encoder = DiCoWEncoder(config)
39
+
40
+ def forward(
41
+ self,
42
+ input_features: Optional[torch.FloatTensor] = None,
43
+ attention_mask: Optional[torch.LongTensor] = None,
44
+ decoder_input_ids: Optional[torch.LongTensor] = None,
45
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
46
+ head_mask: Optional[torch.Tensor] = None,
47
+ decoder_head_mask: Optional[torch.Tensor] = None,
48
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
49
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
50
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
51
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
52
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
53
+ use_cache: Optional[bool] = None,
54
+ output_attentions: Optional[bool] = None,
55
+ output_hidden_states: Optional[bool] = None,
56
+ return_dict: Optional[bool] = None,
57
+ stno_mask: Optional[torch.FloatTensor] = None,
58
+ per_group_sizes: Optional[torch.LongTensor] = None,
59
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutputLosses]:
60
+ r"""
61
+ Returns:
62
+
63
+ Example:
64
+ ```python
65
+ >>> import torch
66
+ >>> from transformers import AutoFeatureExtractor, WhisperModel
67
+ >>> from datasets import load_dataset
68
+
69
+ >>> model = WhisperModel.from_pretrained("openai/whisper-base")
70
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
71
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
72
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
73
+ >>> input_features = inputs.input_features
74
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
75
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
76
+ >>> list(last_hidden_state.shape)
77
+ [1, 2, 512]
78
+ ```"""
79
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
80
+ output_hidden_states = (
81
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
82
+ )
83
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
84
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
85
+
86
+ if encoder_outputs is None:
87
+ input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
88
+
89
+ encoder_outputs = self.encoder(
90
+ input_features,
91
+ output_attentions=output_attentions,
92
+ output_hidden_states=True,
93
+ head_mask=head_mask,
94
+ return_dict=return_dict,
95
+ stno_mask=stno_mask,
96
+ per_group_sizes=per_group_sizes
97
+ )
98
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
99
+ # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
100
+ # raise ValueError("encoder_outputs should be of type BaseModelOutput when return_dict=True.")
101
+
102
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
103
+ decoder_outputs = self.decoder(
104
+ input_ids=decoder_input_ids,
105
+ attention_mask=decoder_attention_mask,
106
+ encoder_hidden_states=encoder_outputs.hidden_states[-1],
107
+ head_mask=decoder_head_mask,
108
+ cross_attn_head_mask=cross_attn_head_mask,
109
+ past_key_values=past_key_values,
110
+ inputs_embeds=decoder_inputs_embeds,
111
+ position_ids=decoder_position_ids,
112
+ use_cache=use_cache,
113
+ output_attentions=output_attentions,
114
+ output_hidden_states=output_hidden_states,
115
+ return_dict=return_dict,
116
+ )
117
+
118
+ if not return_dict:
119
+ return decoder_outputs + encoder_outputs
120
+
121
+ return Seq2SeqModelOutputLogit(
122
+ last_hidden_state=decoder_outputs.last_hidden_state,
123
+ past_key_values=decoder_outputs.past_key_values,
124
+ decoder_hidden_states=decoder_outputs.hidden_states,
125
+ decoder_attentions=decoder_outputs.attentions,
126
+ cross_attentions=decoder_outputs.cross_attentions,
127
+ encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
128
+ encoder_hidden_states=encoder_outputs.hidden_states,
129
+ encoder_attentions=encoder_outputs.attentions,
130
+ encoder_logits=encoder_outputs.logits,
131
+ )
132
+
133
+
134
+ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
135
+ config_class = DiCoWConfig
136
+
137
+ def __init__(self, config: DiCoWConfig):
138
+ super().__init__(config)
139
+ self.model = DiCoW(config)
140
+ self.encoder_logits = None
141
+ self.tokenizer = None
142
+ self.vad_seek_callback = None
143
+ self.stno_mask = None
144
+ self.stno_mask_seek = None
145
+ self.use_enrollment_network = config.use_enrollment_network
146
+ if self.config.contrastive_loss_weight > 0.0:
147
+ self.contrastive_loss_fct = ContrastiveLoss(distance_metric="cosine")
148
+ self.sid_classifier = nn.Linear(config.d_model, config.num_speakers)
149
+ # self.sid_classifier = AttentivePoolingClassifier(config.d_model, config.num_speakers, config.d_model // 4)
150
+ self.embedding_projector = nn.Linear(config.d_model, config.d_model)
151
+
152
+ # We need this setter as we can't pass a function/method as a config argument.
153
+ # JSON serialization fails at that point.
154
+ def set_vad_seek_callback(self, vad_seek_callback):
155
+ self.vad_seek_callback = vad_seek_callback
156
+
157
+ def set_tokenizer(self, tokenizer):
158
+ self.tokenizer = tokenizer
159
+
160
+ def _init_weights(self, module):
161
+ std = self.config.init_std
162
+ fddt_init = self.config.fddt_init
163
+ if isinstance(module, CustomLinearInitialized):
164
+ module.init_fun(module)
165
+ elif isinstance(module, CustomLinear):
166
+ with torch.no_grad():
167
+ if fddt_init == 'random':
168
+ module.weight.data.normal_(mean=0.0, std=std)
169
+ if module.bias is not None:
170
+ module.bias.data.normal_(mean=0.0, std=std)
171
+ elif fddt_init == 'non-disturbing':
172
+ module.weight.data = torch.eye(*module.weight.shape).data
173
+ if module.bias is not None:
174
+ module.bias.data.zero_()
175
+ elif fddt_init == 'disparagement':
176
+ eye = torch.eye(*module.weight.shape)
177
+ eye *= module.init_eye_val
178
+ module.weight.data = eye.data
179
+ if module.bias is not None:
180
+ module.bias.data.zero_()
181
+ elif isinstance(module, CustomDiagonalLinear):
182
+ with torch.no_grad():
183
+ if fddt_init == 'random':
184
+ module.weight.data.normal_(mean=0.0, std=std)
185
+ if module.bias is not None:
186
+ module.bias.data.normal_(mean=0.0, std=std)
187
+ elif fddt_init == 'non-disturbing':
188
+ module.weight.data = torch.ones_like(module.weight.data).data
189
+ if module.bias is not None:
190
+ module.bias.data.zero_()
191
+ elif fddt_init == 'disparagement':
192
+ module.weight.data = module.init_eye_val * torch.ones_like(module.weight.data).data
193
+ if module.bias is not None:
194
+ module.bias.data.zero_()
195
+ elif isinstance(module, FDDT):
196
+ if module.bias_only:
197
+ if fddt_init == 'random':
198
+ module.target_linear.data.normal_(mean=0.0, std=std)
199
+ module.non_target_linear.data.normal_(mean=0.0, std=std)
200
+ module.overlap_linear.data.normal_(mean=0.0, std=std)
201
+ module.silence_linear.data.normal_(mean=0.0, std=std)
202
+ module.scb.data.normal_(mean=0.0, std=std)
203
+ else:
204
+ module.target_linear.data.zero_()
205
+ module.non_target_linear.data.zero_()
206
+ module.overlap_linear.data.zero_()
207
+ module.silence_linear.data.zero_()
208
+ module.scb.data.zero_()
209
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
210
+ module.weight.data.normal_(mean=0.0, std=std)
211
+ if module.bias is not None:
212
+ module.bias.data.zero_()
213
+ elif isinstance(module, nn.Embedding):
214
+ module.weight.data.normal_(mean=0.0, std=std)
215
+ if module.padding_idx is not None:
216
+ module.weight.data[module.padding_idx].zero_()
217
+ elif isinstance(module, WhisperEncoder):
218
+ with torch.no_grad():
219
+ embed_positions = module.embed_positions.weight
220
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
221
+ elif isinstance(module, nn.LayerNorm):
222
+ module.reset_parameters()
223
+ elif isinstance(module, nn.MultiheadAttention):
224
+ module._reset_parameters()
225
+ elif isinstance(module, nn.ConvTranspose1d):
226
+ module.reset_parameters()
227
+ elif isinstance(module, Gate):
228
+ module.gate.data = module.init_val * torch.ones_like(module.gate.data).data
229
+
230
+ def forward(
231
+ self,
232
+ input_features: Optional[torch.FloatTensor] = None,
233
+ stno_mask: Optional[torch.FloatTensor] = None,
234
+ per_group_sizes: Optional[torch.LongTensor] = None,
235
+ attention_mask_enc: Optional[torch.LongTensor] = None,
236
+ attention_mask: Optional[torch.LongTensor] = None,
237
+ decoder_input_ids: Optional[torch.LongTensor] = None,
238
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
239
+ head_mask: Optional[torch.Tensor] = None,
240
+ decoder_head_mask: Optional[torch.Tensor] = None,
241
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
242
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
243
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
244
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
245
+ decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
246
+ labels: Optional[torch.LongTensor] = None,
247
+ upp_labels: Optional[torch.LongTensor] = None,
248
+ use_cache: Optional[bool] = None,
249
+ output_attentions: Optional[bool] = None,
250
+ output_hidden_states: Optional[bool] = None,
251
+ return_dict: Optional[bool] = None,
252
+ is_valid: Optional[bool] = None,
253
+ spk_id: Optional[torch.LongTensor] = None,
254
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
255
+ r"""
256
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
257
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
258
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
259
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
260
+
261
+ Returns:
262
+
263
+ Example:
264
+
265
+ ```python
266
+ >>> import torch
267
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
268
+ >>> from datasets import load_dataset
269
+
270
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
271
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
272
+
273
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
274
+
275
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
276
+ >>> input_features = inputs.input_features
277
+
278
+ >>> generated_ids = model.generate(inputs=input_features)
279
+
280
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
281
+ >>> transcription
282
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
283
+ ```"""
284
+ stno_mask_orig = stno_mask
285
+ enrollments_processed = None
286
+ enroll_stno_mask_reshape = None
287
+ enrollments_enc = None
288
+ if self.training and self.use_enrollment_network:
289
+ attention_mask = attention_mask[::2, ...]
290
+
291
+ enroll_input = input_features[1::2, ...]
292
+ input_features = input_features[::2, ...]
293
+
294
+ is_valid = is_valid[::2, ...]
295
+ enroll_stno_mask = stno_mask[1::2, ...]
296
+ stno_mask = stno_mask[::2, ...]
297
+
298
+ labels = labels[::2, ...]
299
+ upp_labels = upp_labels[::2, ...]
300
+ enrollments_enc = self.model.encoder.encode_enrollment(
301
+ input_features=enroll_input,
302
+ num_layers_to_apply=self.config.spk_embedding_extraction_layer,
303
+ head_mask=head_mask,
304
+ stno_mask=enroll_stno_mask,
305
+ )
306
+ enroll_stno_mask_reshape = (
307
+ ((enroll_stno_mask[:, 1, :] + enroll_stno_mask[:, 3, :]) > 0.5)
308
+ .view(-1, self.config.mt_num_speakers, enroll_stno_mask.shape[2])
309
+ .flatten(1, 2)
310
+ )
311
+ enrollments_processed = enrollments_enc.view(-1, self.config.mt_num_speakers, enrollments_enc.shape[1],
312
+ enrollments_enc.shape[2]).flatten(1, 2)
313
+
314
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
315
+
316
+ if labels is not None:
317
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
318
+ decoder_input_ids = shift_tokens_right(
319
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
320
+ )
321
+
322
+ outputs = self.model(
323
+ input_features,
324
+ attention_mask=attention_mask,
325
+ decoder_input_ids=decoder_input_ids,
326
+ encoder_outputs=encoder_outputs,
327
+ decoder_attention_mask=decoder_attention_mask,
328
+ head_mask=head_mask,
329
+ decoder_head_mask=decoder_head_mask,
330
+ cross_attn_head_mask=cross_attn_head_mask,
331
+ past_key_values=past_key_values,
332
+ decoder_inputs_embeds=decoder_inputs_embeds,
333
+ decoder_position_ids=decoder_position_ids,
334
+ use_cache=use_cache,
335
+ output_attentions=output_attentions,
336
+ output_hidden_states=output_hidden_states,
337
+ return_dict=return_dict,
338
+ stno_mask=stno_mask,
339
+ per_group_sizes=per_group_sizes
340
+ )
341
+
342
+ dec_lm_logits = self.proj_out(outputs.last_hidden_state)
343
+ enc_lm_logits = outputs.encoder_logits
344
+
345
+ loss = None
346
+ ctc_loss = 0
347
+
348
+ # drop fake (padding) examples from labels and logits using the validity mask
349
+ if is_valid is not None:
350
+ if self.config.ctc_weight > 0.0:
351
+ enc_lm_logits = enc_lm_logits[is_valid]
352
+ dec_lm_logits = dec_lm_logits[is_valid]
353
+ labels = labels[is_valid]
354
+ upp_labels = upp_labels[is_valid]
355
+ if labels is not None and self.config.ctc_weight > 0.0:
356
+ enc_labels = labels.clone()
357
+ for token in self.tokenizer.prefix_tokens:
358
+ if (enc_labels[:, 0] == token).all():
359
+ enc_labels = enc_labels[:, 1:]
360
+ enc_labels[enc_labels == self.config.eos_token_id] = -100
361
+
362
+ ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
363
+
364
+ if labels is not None:
365
+ loss_fct = CrossEntropyLoss(reduction='none')
366
+ # move labels to correct device to enable PP
367
+ labels = labels.to(dec_lm_logits.device)
368
+ dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
369
+ dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
370
+ dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
371
+ if wandb.run is not None:
372
+ wandb.log({"dec_loss": dec_loss})
373
+ wandb.log({"ctc_loss": ctc_loss})
374
+ loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
375
+
376
+ if hasattr(self, "contrastive_loss_fct"):
377
+ stno_per_spk_pair = stno_mask.view(-1, self.config.mt_num_speakers, stno_mask.shape[1],
378
+ stno_mask.shape[2])
379
+ anchors = ((stno_per_spk_pair[:, :, 1, :] + stno_per_spk_pair[:, :, 3, :]) > 0.5).flatten(1)
380
+ intermediate_states = (
381
+ outputs.encoder_hidden_states[self.config.spk_embedding_extraction_layer]
382
+ .view(
383
+ -1, self.config.mt_num_speakers, stno_mask.shape[2],
384
+ outputs.encoder_hidden_states[self.config.spk_embedding_extraction_layer].shape[-1],
385
+ )
386
+ .flatten(1, 2)
387
+ )
388
+ valid_pairs = is_valid.view((-1, self.config.mt_num_speakers)).all(dim=-1)
389
+
390
+ contrastive_loss = self.contrastive_loss_fct(
391
+ self.embedding_projector(intermediate_states[valid_pairs]),
392
+ anchors[valid_pairs],
393
+ self.embedding_projector(enrollments_processed[valid_pairs]) if enrollments_processed is not None else None,
394
+ enroll_stno_mask_reshape[valid_pairs] if enroll_stno_mask_reshape is not None else None
395
+ )
396
+ if wandb.run is not None:
397
+ wandb.log({"contrastive_loss": contrastive_loss})
398
+ loss += self.config.contrastive_loss_weight * contrastive_loss
399
+
400
+ embeds = outputs.encoder_hidden_states[self.config.spk_embedding_extraction_layer]
401
+ all_embeds = torch.empty((embeds.shape[0] * 2, embeds.shape[1], embeds.shape[2]), dtype=embeds.dtype,
402
+ device=embeds.device)
403
+ all_embeds[::2] = embeds
404
+ all_embeds[1::2] = enrollments_enc
405
+ spk_logits = self.sid_classifier(self.embedding_projector(all_embeds))
406
+ spk_id_mask = (stno_mask_orig[:, 1] + stno_mask_orig[:, 3]) > 0.5
407
+ spk_loss_fun = CrossEntropyLoss(reduction='mean')
408
+ spk_labels = spk_id[:, None].repeat((1, spk_logits.shape[1]))[spk_id_mask]
409
+ spk_loss = spk_loss_fun(spk_logits[spk_id_mask], spk_labels)
410
+ if wandb.run is not None:
411
+ spk_id_acc = (torch.argmax(spk_logits[spk_id_mask], dim=-1) == spk_labels).sum() / len(spk_labels[spk_labels != -100])
412
+ wandb.log({"spk_loss": spk_loss, "spk_id_acc": spk_id_acc})
413
+ loss += spk_loss
414
+
415
+ if not return_dict:
416
+ output = (dec_lm_logits,) + outputs[1:]
417
+ return ((loss,) + output) if loss is not None else output
418
+
419
+ return Seq2SeqLMOutputLosses(
420
+ loss=loss,
421
+ logits=dec_lm_logits,
422
+ past_key_values=outputs.past_key_values,
423
+ decoder_hidden_states=outputs.decoder_hidden_states,
424
+ decoder_attentions=outputs.decoder_attentions,
425
+ cross_attentions=outputs.cross_attentions,
426
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
427
+ encoder_hidden_states=outputs.encoder_hidden_states,
428
+ encoder_attentions=outputs.encoder_attentions,
429
+ encoder_logits=enc_lm_logits,
430
+ )
431
+
432
+ def _get_feat_extract_output_lengths(self, attention_mask: torch.Tensor) -> torch.Tensor:
433
+ return (self.model.encoder._get_feat_extract_output_lengths(attention_mask) / 4).ceil()
434
+
435
+ def freeze_except(self, prefixes_to_preheat):
436
+ for name, param in self.named_parameters():
437
+ param.requires_grad = False
438
+ for prefix in prefixes_to_preheat:
439
+ if name.startswith(prefix):
440
+ param.requires_grad = True
441
+
442
+ def suppress_interactions(self):
443
+ """This method suppress final projection in CoAttention blocks to let the original information flow through"""
444
+ for name, param in self.named_parameters():
445
+ if "interaction" in name and "cat_proj" in name:
446
+ with torch.no_grad():
447
+ if "bias" in name:
448
+ param[:] = 0.
449
+ else:
450
+ param[:] *= 0.001
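
For readers skimming the diff, the sketch below shows how `DiCoWForConditionalGeneration` is typically driven once the files above are on the Hub. This is an illustrative sketch only, not part of the uploaded files: the repository id is a placeholder, loading via `AutoModelForSpeechSeq2Seq` assumes `config.json` registers the custom classes for `trust_remote_code`, and the STNO mask layout, (batch, 4, encoder_frames) with silence/target/non-target/overlap channels, is inferred from the mask indexing used throughout `modeling_dicow.py`.

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq

repo_id = "<this-repo-id>"  # placeholder -- replace with the actual Hub id
# Assumes config.json provides an auto_map entry for the DiCoW classes.
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# One Whisper-style 30 s log-mel window: (batch, num_mel_bins, 3000 mel frames),
# which the encoder downsamples to 1500 frames.
input_features = torch.zeros(1, model.config.num_mel_bins, 3000)

# STNO mask for the target speaker: (batch, 4, encoder_frames); the channel order
# is assumed to be silence / target / non-target / overlap.
stno_mask = torch.zeros(1, 4, 1500)
stno_mask[:, 1] = 1.0  # treat the whole window as target-speaker speech

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
with torch.no_grad():
    out = model(
        input_features=input_features,
        stno_mask=stno_mask,
        per_group_sizes=torch.tensor([1]),  # one speaker in a single group
        decoder_input_ids=decoder_input_ids,
    )
print(out.logits.shape)  # expected: (1, 1, vocab_size)
```

Long-form decoding goes through `model.generate`, which `DiCoWGenerationMixin` (generation.py, not shown here) overrides; judging by the `stno_mask`/`stno_mask_seek` attributes above, it threads the same per-speaker STNO information through chunked decoding.
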
utils.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from transformers import WhisperTimeStampLogitsProcessor
5
+
6
+
7
+ def remove_fake_elements(inputs, per_group_sizes):
8
+ max_spks = per_group_sizes.max()
9
+ number_of_groups = per_group_sizes.shape[0]
10
+ outputs = []
11
+ inputs = inputs.view(number_of_groups, max_spks, *inputs.shape[1:])
12
+ for i, group_size in enumerate(per_group_sizes):
13
+ outputs.append(inputs[i, :group_size])
14
+ outputs = torch.cat(outputs, dim=0)
15
+ return outputs
16
+
17
+
18
+ class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
19
+ def __init__(
20
+ self, generate_config, begin_index: Optional[int] = None,
21
+ _detect_timestamp_from_logprob: Optional[bool] = None
22
+ ): # support for the kwargs
23
+ self.no_timestamps_token_id = generate_config.no_timestamps_token_id
24
+ self.timestamp_begin = generate_config.no_timestamps_token_id + 1
25
+ self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id
26
+
27
+ # this variable is mostly just used for testing
28
+ self._detect_timestamp_from_logprob = (
29
+ _detect_timestamp_from_logprob
30
+ if _detect_timestamp_from_logprob is not None
31
+ else getattr(generate_config, "_detect_timestamp_from_logprob", True)
32
+ )
33
+
34
+ num_forced_ids = (
35
+ len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
36
+ )
37
+ self.begin_index = begin_index or (num_forced_ids + 1)
38
+
39
+ self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
40
+ self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)
41
+ # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
42
+ # self.max_initial_timestamp_index = 50
43
+
44
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
45
+ # suppress <|notimestamps|> which is handled by without_timestamps
46
+ scores_processed = scores.clone()
47
+ scores_processed[:, self.no_timestamps_token_id] = -float("inf")
48
+
49
+ # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
50
+ for k in range(input_ids.shape[0]):
51
+ sampled_tokens = input_ids[k, self.begin_index:]
52
+ seq = list(sampled_tokens.tolist())
53
+
54
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
55
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
56
+
57
+ if last_was_timestamp:
58
+ if penultimate_was_timestamp: # has to be non-timestamp
59
+ scores_processed[k, self.timestamp_begin:] = -float("inf")
60
+ else: # cannot be normal text tokens
61
+ scores_processed[k, : self.eos_token_id] = -float("inf")
62
+
63
+ timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
64
+ if timestamps.numel() > 0:
65
+ # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
66
+ # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
67
+ if last_was_timestamp and not penultimate_was_timestamp:
68
+ timestamp_last = timestamps[-1]
69
+ else:
70
+ # Avoid to emit <|0.00|> again
71
+ timestamp_last = timestamps[-1] + 1
72
+
73
+ scores_processed[k, self.timestamp_begin: timestamp_last] = -float("inf")
74
+
75
+ # apply the `max_initial_timestamp` option
76
+ if input_ids.shape[1] == self.begin_index:
77
+ eos_scores = scores_processed[:, self.eos_token_id].clone()
78
+ scores_processed[:, : self.timestamp_begin] = -float("inf")
79
+ scores_processed[:, self.eos_token_id] = eos_scores
80
+
81
+ if self.max_initial_timestamp_index is not None:
82
+ last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
83
+ scores_processed[:, last_allowed + 1:] = -float("inf")
84
+ if self.min_initial_timestamp_index is not None:
85
+ first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
86
+ scores_processed[:, self.timestamp_begin:first_allowed] = -float("inf")
87
+
88
+ # if sum of probability over timestamps is above any other token, sample timestamp
89
+ logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
90
+ for k in range(input_ids.shape[0]):
91
+ timestamp_logprob = logprobs[k, self.timestamp_begin:].logsumexp(dim=-1)
92
+ max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
93
+ if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
94
+ scores_processed[k, : self.timestamp_begin] = -float("inf")
95
+
96
+ return scores_processed
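
To make the padding convention concrete, here is a tiny, self-contained usage sketch for `remove_fake_elements` (the import path is hypothetical; the function itself is defined above). Batches are laid out as `number_of_groups * max_spks` rows, and `per_group_sizes` states how many of those rows are real in each group:

```python
import torch

from utils import remove_fake_elements  # hypothetical path; adjust to how the package is imported

# Two groups padded to max_spks = 3 speakers each; the second group has only 2 real speakers.
per_group_sizes = torch.tensor([3, 2])
inputs = torch.arange(6, dtype=torch.float32).unsqueeze(-1)  # (2 * 3, 1)

trimmed = remove_fake_elements(inputs, per_group_sizes)
print(trimmed.squeeze(-1))  # tensor([0., 1., 2., 3., 4.]) -- padded row 5 is dropped
```

Compared with the stock `WhisperTimeStampLogitsProcessor`, the custom subclass above mainly adds the optional `min_initial_timestamp_index`, which forbids initial timestamp tokens earlier than a given index in the same way `max_initial_timestamp_index` caps late ones.
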