Update modeling_dicow.py
Browse files- modeling_dicow.py +0 -29
modeling_dicow.py
CHANGED
|
@@ -274,35 +274,6 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
|
|
| 274 |
>>> transcription
|
| 275 |
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
| 276 |
```"""
|
| 277 |
-
stno_mask_orig = stno_mask
|
| 278 |
-
enrollments_processed = None
|
| 279 |
-
enroll_stno_mask_reshape = None
|
| 280 |
-
enrollments_enc = None
|
| 281 |
-
if self.training and self.use_enrollment_network:
|
| 282 |
-
attention_mask = attention_mask[::2, ...]
|
| 283 |
-
|
| 284 |
-
enroll_input = input_features[1::2, ...]
|
| 285 |
-
input_features = input_features[::2, ...]
|
| 286 |
-
|
| 287 |
-
is_valid = is_valid[::2, ...]
|
| 288 |
-
enroll_stno_mask = stno_mask[1::2, ...]
|
| 289 |
-
stno_mask = stno_mask[::2, ...]
|
| 290 |
-
|
| 291 |
-
labels = labels[::2, ...]
|
| 292 |
-
upp_labels = upp_labels[::2, ...]
|
| 293 |
-
enrollments_enc = self.model.encoder.encode_enrollment(
|
| 294 |
-
input_features=enroll_input,
|
| 295 |
-
num_layers_to_apply=self.config.spk_embedding_extraction_layer,
|
| 296 |
-
head_mask=head_mask,
|
| 297 |
-
stno_mask=enroll_stno_mask,
|
| 298 |
-
)
|
| 299 |
-
enroll_stno_mask_reshape = ((enroll_stno_mask[:, 1, :] + enroll_stno_mask[:, 3, :]) > 0.5).view(-1,
|
| 300 |
-
self.config.mt_num_speakers,
|
| 301 |
-
enroll_stno_mask.shape[
|
| 302 |
-
2]).flatten(1,
|
| 303 |
-
2)
|
| 304 |
-
enrollments_processed = enrollments_enc.view(-1, self.config.mt_num_speakers, enrollments_enc.shape[1],
|
| 305 |
-
enrollments_enc.shape[2]).flatten(1, 2)
|
| 306 |
|
| 307 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 308 |
|
|
|
|
| 274 |
>>> transcription
|
| 275 |
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
| 276 |
```"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 279 |
|