import io

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

app = FastAPI()

# Load the Wav2Vec2 processor and model once at startup.
# Note: this checkpoint is fine-tuned for phoneme recognition, so the
# output is an eSpeak-style phonemic transcription rather than plain text.
MODEL_NAME = "facebook/wav2vec2-lv-60-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Read the uploaded file into memory and decode it as audio.
        audio_bytes = await file.read()
        audio_input, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

        # Downmix multi-channel audio to mono (if needed).
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)

        # Resample to the 16 kHz rate the model was trained on (if needed).
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=target_sample_rate
            )
            audio_input = resampler(audio_input)

        # Drop the channel dimension: the processor expects a 1-D waveform.
        audio_input = audio_input.squeeze(0)

        # Preprocess the waveform into model input features.
        input_values = processor(
            audio_input, sampling_rate=target_sample_rate, return_tensors="pt"
        ).input_values

        # Run inference without tracking gradients.
        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy CTC decoding: take the most likely token per frame, then
        # let the processor collapse repeats and blanks into a transcription.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)

        return {"transcription": transcription[0]}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
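
# --- Usage sketch ---
# A minimal way to run and exercise the endpoint, assuming this file is
# saved as main.py and uvicorn is installed (both assumptions, not part of
# the original snippet):
#
#   uvicorn main:app --port 8000
#   curl -X POST "http://localhost:8000/transcribe/" -F "file=@sample.wav"
#
# Alternatively, launch the server directly from this file:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)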