"""Load Whisper large-v3-turbo into a Hugging Face ASR pipeline and warm it up.

Importing this module is intentionally side-effectful. It loads the model,
falling back through progressively more conservative settings. It builds the
module-level ``transcribe_pipeline``, optionally wraps the model with
``torch.compile``, and runs one dummy transcription.
"""

import logging
import os

import numpy as np
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

logger = logging.getLogger(__name__)

MODEL_ID = "openai/whisper-large-v3-turbo"
LANGUAGE = "english"

# Prefer the GPU, but do not fail on CPU-only hosts at model.to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_device_map = True
try_compile_model = True
# Only request flash attention when it can actually be used: it needs CUDA and
# the flash_attn package. Otherwise from_pretrained raises instead of loading.
try_use_flash_attention = device == "cuda" and is_flash_attn_2_available()
# Half precision is a GPU optimisation; keep float32 weights on CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
np_dtype = np.float16

# from_pretrained/.to failures that a more conservative configuration may fix.
# Examples: RuntimeError (CUDA OOM, device placement), ImportError (optional
# backend such as flash_attn/accelerate missing), ValueError (unsupported
# attn_implementation or device_map).
_RETRYABLE_LOAD_ERRORS = (RuntimeError, ImportError, ValueError)


def _load_asr_model():
    """Load MODEL_ID, retrying with safer settings when loading fails.

    Attempts are tried in order, skipping duplicates:
      1. preferred attention (flash_attention_2 if usable, else sdpa), with
         device_map="auto" when ``use_device_map`` is set;
      2. preferred attention, no device_map, then ``.to(device)``;
      3. sdpa attention, no device_map, then ``.to(device)``.

    Returns:
        The loaded model, either dispatched by device_map or moved to ``device``.

    Raises:
        Exception: the error from the last attempt. Non-retryable errors (e.g.
            OSError for an unknown model ID) are re-raised immediately. Both
            cases are logged first.
    """
    preferred_attn = "flash_attention_2" if try_use_flash_attention else "sdpa"
    # dict.fromkeys de-duplicates while preserving order.
    attempts = list(
        dict.fromkeys(
            [
                (preferred_attn, use_device_map),
                (preferred_attn, False),
                ("sdpa", False),
            ]
        )
    )
    for index, (attn_impl, with_device_map) in enumerate(attempts):
        try:
            loaded = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_ID,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation=attn_impl,
                device_map="auto" if with_device_map else None,
            )
            if not with_device_map:
                loaded.to(device)
            return loaded
        except Exception as exc:
            is_last = index == len(attempts) - 1
            if is_last or not isinstance(exc, _RETRYABLE_LOAD_ERRORS):
                logger.error("Error loading ASR model: %s", exc)
                logger.error("Are you providing a valid model ID? %s", MODEL_ID)
                raise
            next_attn, next_device_map = attempts[index + 1]
            logger.warning(
                "Loading with attn_implementation=%s, device_map=%s failed (%s); "
                "retrying with attn_implementation=%s, device_map=%s",
                attn_impl,
                "auto" if with_device_map else None,
                exc,
                next_attn,
                "auto" if next_device_map else None,
            )
    # Not reached: `attempts` is never empty and every iteration returns or raises.
    raise AssertionError("unreachable")


model = _load_asr_model()

processor = AutoProcessor.from_pretrained(MODEL_ID)
transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
)

# One second of uniform random noise, assuming the feature extractor's 16 kHz
# sampling rate. The content is irrelevant; the aim is to pay one-off start-up
# costs before the first real request.
warmup_audio = np.random.rand(16000).astype(np_dtype)


def _warm_up():
    """Run one dummy transcription through ``transcribe_pipeline``."""
    logger.info("Warming up Whisper model with dummy input")
    transcribe_pipeline(warmup_audio)
    logger.info("Model warmup complete")


warmed_up = False
if try_compile_model:
    try:
        transcribe_pipeline.model = torch.compile(
            transcribe_pipeline.model, mode="max-autotune"
        )
        # torch.compile only wraps the module; compilation (and any failure)
        # happens on the first call. The warmup must therefore run inside this
        # try for the fallback below to be reachable.
        # NOTE(review): the compiled wrapper forwards attribute lookups such as
        # `.generate` to the original module, so generate-based decoding may
        # bypass the compiled forward entirely. HF's Whisper docs compile
        # `model.forward` with a static cache instead. Confirm the speed-up.
        _warm_up()
        warmed_up = True
    except Exception as e:
        logger.warning("Error compiling model: %s", e)
        logger.warning("Proceeding without compiling the model")
        # Restore the eager model so the pipeline stays usable.
        transcribe_pipeline.model = model
else:
    logger.warning("Proceeding without compiling the model (requirements not met)")

if not warmed_up:
    _warm_up()