import os

# Cache paths must be set BEFORE importing transformers/torch:
# huggingface_hub reads HF_HOME when it is first imported, so assigning
# these after `from transformers import ...` (as the original code did)
# has no effect and downloads go to a possibly non-writable default dir.
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

import io
import re
import subprocess

import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor


def clean_phoneme_string(raw: str) -> str:
    """Insert spaces between adjacent uppercase phoneme characters."""
    return re.sub(r"(?<=[A-Z])(?=[A-Z])", " ", raw).strip()


# FastAPI setup
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load both models once at startup; re-raise on failure so the process
# fails fast instead of serving requests with missing models.
try:
    # Phoneme-recognition model (CTC head trained on TIMIT phonemes)
    phoneme_processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
    phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")

    # Speech-to-text model
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    print("✅ Models loaded successfully.")
except Exception as e:
    print("❌ Model load error:", str(e))
    raise


def convert_webm_to_wav(webm_bytes: bytes) -> io.BytesIO:
    """Convert webm audio bytes to an in-memory WAV stream via ffmpeg.

    Raises:
        RuntimeError: if ffmpeg exits non-zero (e.g. corrupt/empty input).
    """
    process = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.returncode != 0:
        print("❌ ffmpeg error:", process.stderr.decode())
        raise RuntimeError("ffmpeg conversion failed")
    return io.BytesIO(process.stdout)


def _ctc_decode(processor: Wav2Vec2Processor, model: Wav2Vec2ForCTC,
                waveform: torch.Tensor, sample_rate: int) -> str:
    """Run one wav2vec2 CTC model over a mono waveform and greedy-decode.

    Shared by the phoneme and speech-to-text passes, which were previously
    duplicated verbatim in the endpoint.
    """
    inputs = processor(waveform.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(inputs).logits
    ids = torch.argmax(logits, dim=-1)
    return processor.decode(ids[0])


@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded webm clip into phonemes and plain text.

    Returns a dict with keys "phonemes" and "transcript"; on any failure
    returns placeholder error strings instead of an HTTP error (kept for
    backward compatibility with the existing frontend).
    """
    try:
        contents = await audio.read()
        wav_io = convert_webm_to_wav(contents)
        waveform, sample_rate = torchaudio.load(wav_io)

        # Downmix multi-channel audio to mono: with stereo input,
        # .squeeze() alone would leave a 2-D (channels, samples) tensor
        # and break the processor call downstream.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Both wav2vec2 checkpoints expect 16 kHz input.
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            sample_rate = 16000

        phonemes = clean_phoneme_string(
            _ctc_decode(phoneme_processor, phoneme_model, waveform, sample_rate)
        )
        transcript = _ctc_decode(stt_processor, stt_model, waveform, sample_rate)

        return {"phonemes": phonemes, "transcript": transcript}
    except Exception as e:
        print("❌ Transcription error:", str(e))
        return {"phonemes": "[Error]", "transcript": "[Error: " + str(e) + "]"}


@app.get("/")
def root():
    """Liveness probe for the backend."""
    return {"message": "Backend is running"}