Update generation.py
Browse files- generation.py +3 -2
    	
        generation.py
    CHANGED
    
    | @@ -1197,8 +1197,9 @@ class DiCoWGenerationMixin(WhisperForConditionalGeneration): | |
| 1197 | 
             
                        self.vad_seek_callback(kwargs["stno_mask"])
         | 
| 1198 | 
             
                    if "is_valid" in kwargs:
         | 
| 1199 | 
             
                        kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
         | 
| 1200 | 
            -
                     | 
| 1201 | 
            -
             | 
|  | |
| 1202 | 
             
                    return kwargs
         | 
| 1203 |  | 
| 1204 | 
             
                def generate_with_fallback(
         | 
|  | |
| 1197 | 
             
                        self.vad_seek_callback(kwargs["stno_mask"])
         | 
| 1198 | 
             
                    if "is_valid" in kwargs:
         | 
| 1199 | 
             
                        kwargs['is_valid'] = kwargs["is_valid"][batch_idx_map]
         | 
| 1200 | 
            +
                    if "labels" in kwargs:
         | 
| 1201 | 
            +
                        kwargs['labels'] = kwargs["labels"][batch_idx_map]
         | 
| 1202 | 
            +
                        kwargs['upp_labels'] = kwargs["upp_labels"][batch_idx_map]
         | 
| 1203 | 
             
                    return kwargs
         | 
| 1204 |  | 
| 1205 | 
             
                def generate_with_fallback(
         | 
