Commit 77b8369 · Parent(s): cf8f9cf

Speculative Decoding doesn't work yet with Whisper-v3 (#23)

- Speculative Decoding doesn't work yet with Whisper-v3 (3264a14fc680cf148949b17860f045ce2cf81edb)
README.md CHANGED

````diff
@@ -258,57 +258,6 @@ result = pipe(sample, return_timestamps=True, generate_kwargs={"language": "fren
 print(result["chunks"])
 ```
 
-## Speculative Decoding
-
-Whisper `tiny` can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically
-ensures the exact same outputs as Whisper are obtained while being 2 times faster. This makes it the perfect drop-in
-replacement for existing Whisper pipelines, since the same outputs are guaranteed.
-
-In the following code-snippet, we load the assistant Distil-Whisper model standalone to the main Whisper pipeline. We then
-specify it as the "assistant model" for generation:
-
-```python
-from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
-import torch
-from datasets import load_dataset
-
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-assistant_model_id = "openai/whisper-tiny"
-
-assistant_model = AutoModelForCausalLM.from_pretrained(
-    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-)
-assistant_model.to(device)
-
-model_id = "openai/whisper-large-v3"
-
-model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-)
-model.to(device)
-
-processor = AutoProcessor.from_pretrained(model_id)
-
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model=model,
-    tokenizer=processor.tokenizer,
-    feature_extractor=processor.feature_extractor,
-    max_new_tokens=128,
-    generate_kwargs={"assistant_model": assistant_model},
-    torch_dtype=torch_dtype,
-    device=device,
-)
-
-dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-sample = dataset[0]["audio"]
-
-result = pipe(sample)
-print(result["text"])
-```
-
 ## Additional Speed & Memory Improvements
 
 You can apply additional speed and memory improvements to Whisper-large-v3 which we cover in the following.
````
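The commit message does not say why speculative decoding fails with Whisper-v3. A plausible explanation, offered here as an assumption rather than as the maintainers' stated reasoning, is that an assistant model must share the main model's tokenizer and vocabulary, while `whisper-large-v3` extended the vocabulary (adding the Cantonese language token) and moved from 80 to 128 mel bins, so earlier checkpoints such as `whisper-tiny` are no longer drop-in draft models. A minimal diagnostic sketch for checking this mismatch from the model configs:

```python
# Hypothetical diagnostic, not part of this commit: compare the configs of the
# main and assistant checkpoints used in the removed README snippet.
# Speculative decoding requires the draft model to propose tokens from the same
# vocabulary the main model scores.
from transformers import AutoConfig

main_cfg = AutoConfig.from_pretrained("openai/whisper-large-v3")
draft_cfg = AutoConfig.from_pretrained("openai/whisper-tiny")

# vocab_size: 51866 for large-v3 (adds the <|yue|> token) vs 51865 for tiny
print("vocab match:", main_cfg.vocab_size == draft_cfg.vocab_size)

# num_mel_bins: 128 for large-v3 vs 80 for tiny, so the two models cannot
# consume the same log-mel features from a single feature extractor
print("mel-bin match:", main_cfg.num_mel_bins == draft_cfg.num_mel_bins)
```

If either check fails, the assistant cannot serve as a draft model for `large-v3`, which is consistent with the section being pulled from the README until a compatible assistant exists.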

