Lakoc commited on
Commit
e1af3d4
·
verified ·
1 Parent(s): 57fe226

Update generation.py

Browse files
Files changed (1) hide show
  1. generation.py +3 -2
generation.py CHANGED
@@ -1197,8 +1197,9 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration):
1197
  self.vad_seek_callback(kwargs["stno_mask"])
1198
  if "is_valid" in kwargs:
1199
  kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
1200
- kwargs['labels'] = kwargs["labels"][batch_idx_map]
1201
- kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
 
1202
  return kwargs
1203
 
1204
  def generate_with_fallback(
 
1197
  self.vad_seek_callback(kwargs["stno_mask"])
1198
  if "is_valid" in kwargs:
1199
  kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
1200
+ if "labels" in kwargs:
1201
+ kwargs['labels'] = kwargs["labels"][batch_idx_map]
1202
+ kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
1203
  return kwargs
1204
 
1205
  def generate_with_fallback(