Lakoc commited on
Commit
fe4ba0e
·
verified ·
1 Parent(s): 897d695

Update modeling_dicow.py

Browse files
Files changed (1) hide show
  1. modeling_dicow.py +0 -29
modeling_dicow.py CHANGED
@@ -274,35 +274,6 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
274
  >>> transcription
275
  ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
276
  ```"""
277
- stno_mask_orig = stno_mask
278
- enrollments_processed = None
279
- enroll_stno_mask_reshape = None
280
- enrollments_enc = None
281
- if self.training and self.use_enrollment_network:
282
- attention_mask = attention_mask[::2, ...]
283
-
284
- enroll_input = input_features[1::2, ...]
285
- input_features = input_features[::2, ...]
286
-
287
- is_valid = is_valid[::2, ...]
288
- enroll_stno_mask = stno_mask[1::2, ...]
289
- stno_mask = stno_mask[::2, ...]
290
-
291
- labels = labels[::2, ...]
292
- upp_labels = upp_labels[::2, ...]
293
- enrollments_enc = self.model.encoder.encode_enrollment(
294
- input_features=enroll_input,
295
- num_layers_to_apply=self.config.spk_embedding_extraction_layer,
296
- head_mask=head_mask,
297
- stno_mask=enroll_stno_mask,
298
- )
299
- enroll_stno_mask_reshape = ((enroll_stno_mask[:, 1, :] + enroll_stno_mask[:, 3, :]) > 0.5).view(-1,
300
- self.config.mt_num_speakers,
301
- enroll_stno_mask.shape[
302
- 2]).flatten(1,
303
- 2)
304
- enrollments_processed = enrollments_enc.view(-1, self.config.mt_num_speakers, enrollments_enc.shape[1],
305
- enrollments_enc.shape[2]).flatten(1, 2)
306
 
307
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
308
 
 
274
  >>> transcription
275
  ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
276
  ```"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
279