"""FastAPI backend that transcribes uploaded audio with two Wav2Vec2 models.

One model emits a phoneme string, the other an English transcript.  The
phoneme string is looked up in a reverse CMUdict index to guess a word, and
the endpoint reports all intermediate results plus a "resolved" value.
"""
import io
import os
import re
import subprocess
from difflib import SequenceMatcher

# Environment paths
# These must be exported BEFORE transformers/torch are imported: the Hugging
# Face libraries resolve their cache directory from HF_HOME at import time,
# so assigning it afterwards has no effect on where models are cached.
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

import nltk  # noqa: E402
import torch  # noqa: E402
import torchaudio  # noqa: E402
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from nltk.corpus import cmudict  # noqa: E402
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor  # noqa: E402

# Sample rate (Hz) the audio is resampled to before it is handed to either
# processor.
TARGET_SAMPLE_RATE = 16000

# Zero-width position between two adjacent uppercase ASCII letters.
_ADJACENT_UPPERCASE = re.compile(r"(?<=[A-Z])(?=[A-Z])")

# NLTK config
nltk.data.path.append("/app/cache/nltk_data")
cmu = cmudict.dict()

# Build reverse phoneme-to-word dictionary: the space-joined pronunciation is
# the key, the value lists every word that has that pronunciation.
phoneme_to_word = {}
for word, phoneme_lists in cmu.items():
    for phonemes in phoneme_lists:
        key = " ".join(phonemes)
        phoneme_to_word.setdefault(key, []).append(word)


def clean_phoneme_string(raw: str) -> str:
    """Insert a space between every pair of adjacent uppercase letters.

    Leading and trailing whitespace is stripped from the result, e.g.
    ``"ABC"`` becomes ``"A B C"``.  Other characters are left untouched.
    """
    return _ADJACENT_UPPERCASE.sub(" ", raw).strip()


def guess_word_from_phonemes(phoneme_string: str) -> str:
    """Return the first word whose pronunciation equals *phoneme_string*.

    The lookup is exact on the upper-cased string; ``"[unknown]"`` is
    returned when no pronunciation matches.
    """
    # NOTE(review): CMUdict pronunciations carry stress digits on vowels
    # (e.g. "AH0"), so a phoneme string without them will never match here.
    # Confirm the phoneme model's output alphabet against the index keys.
    key = phoneme_string.upper()
    return phoneme_to_word.get(key, ["[unknown]"])[0]


def words_are_close(word1: str, word2: str, threshold: float = 0.8) -> bool:
    """Return True when the case-insensitive similarity ratio of the two
    strings is at least *threshold*."""
    ratio = SequenceMatcher(None, word1.lower(), word2.lower()).ratio()
    return ratio >= threshold


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8080"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models
try:
    phoneme_processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
    phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")

    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    print("✅ Models loaded successfully.")
except Exception as e:
    # Log for the container output, then abort start-up: the service cannot
    # do anything useful without its models.
    print("❌ Model load error:", str(e))
    raise


def convert_webm_to_wav(webm_bytes: bytes) -> io.BytesIO:
    """Convert an in-memory audio file to WAV by piping it through ffmpeg.

    Args:
        webm_bytes: Raw bytes of the uploaded audio file.

    Returns:
        A ``BytesIO`` holding the WAV data written by ffmpeg to stdout.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    process = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    if process.returncode != 0:
        # errors="replace" keeps an undecodable ffmpeg message from raising
        # UnicodeDecodeError and masking the RuntimeError below.
        print("❌ ffmpeg error:", process.stderr.decode(errors="replace"))
        raise RuntimeError("ffmpeg conversion failed")
    return io.BytesIO(process.stdout)


def _greedy_ctc_decode(processor, model, samples, sample_rate: int) -> str:
    """Run *model* on 1-D *samples* and decode the arg-max token per frame."""
    input_values = processor(samples, sampling_rate=sample_rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip with both models.

    Returns a dict with the keys ``phonemes``, ``phoneme_guess``,
    ``transcript`` and ``resolved``.  ``resolved`` is the transcript when it
    is close to the phoneme-based word guess, otherwise the phoneme string.
    On any failure every field is set to ``"[Error]"``.
    """
    try:
        contents = await audio.read()
        wav_io = convert_webm_to_wav(contents)
        waveform, sample_rate = torchaudio.load(wav_io)

        if sample_rate != TARGET_SAMPLE_RATE:
            resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resample(waveform)
            sample_rate = TARGET_SAMPLE_RATE

        # torchaudio.load returns (channels, frames).  Averaging over the
        # channel axis always yields 1-D audio; squeeze() only did so for
        # mono input and passed stereo on as a 2-D tensor.
        samples = waveform.mean(dim=0)

        # Run phoneme model
        raw_phonemes = _greedy_ctc_decode(phoneme_processor, phoneme_model, samples, sample_rate)
        phonemes = clean_phoneme_string(raw_phonemes)
        phoneme_guess = guess_word_from_phonemes(phonemes)

        # Run speech-to-text model
        transcript = _greedy_ctc_decode(stt_processor, stt_model, samples, sample_rate)

        # Decide what to return
        if words_are_close(transcript, phoneme_guess):
            resolved = transcript
        else:
            resolved = phonemes

        return {
            "phonemes": phonemes,
            "phoneme_guess": phoneme_guess,
            "transcript": transcript,
            "resolved": resolved
        }
    except Exception as e:
        # Deliberate best-effort boundary: the response keeps its usual
        # shape with placeholder values instead of surfacing an HTTP error.
        print("❌ Transcription error:", str(e))
        return {
            "phonemes": "[Error]",
            "phoneme_guess": "[Error]",
            "transcript": "[Error]",
            "resolved": "[Error]"
        }


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"message": "Backend is running"}