import os
import io
import subprocess
from difflib import SequenceMatcher

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from phonemizer import phonemize
from faster_whisper import WhisperModel

# Set cache paths (must be set before any model download is triggered)
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models
# Phoneme model is assumed custom or should be replaced as needed
phon_proc = torch.hub.load('pytorch/fairseq', 'wav2vec2_vox_960h')
whisper_model = WhisperModel("small", compute_type="float32")


def convert_webm_to_wav(bts):
    """Convert uploaded WebM bytes to 16 kHz mono WAV using ffmpeg via pipes."""
    p = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000", "-ac", "1", "pipe:1"],
        input=bts,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if p.returncode != 0:
        raise RuntimeError(p.stderr.decode())
    return io.BytesIO(p.stdout)


def normalize_phoneme_string(s):
    """Keep only alphabetic characters so phoneme strings compare cleanly."""
    return ''.join(c for c in s if c.isalpha())


def words_sim(a, b, threshold=0.35):
    return SequenceMatcher(None, a.lower(), b.lower()).ratio() >= threshold


@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    data = await audio.read()
    wav_io = convert_webm_to_wav(data)

    waveform, sample_rate = torchaudio.load(wav_io)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000

    # Whisper transcription with word-level timestamps.
    # Rewind the buffer first: torchaudio.load has already consumed it.
    wav_io.seek(0)
    segments, _ = whisper_model.transcribe(wav_io, word_timestamps=True)

    words = []
    transcript = []
    for seg in segments:
        for w in seg.words:
            words.append(w)
            transcript.append(w.word)
    full_text = " ".join(transcript)

    # Expected phonemes per word (phonemize returns a list when given a list)
    resolved_output = []
    phoneme_text_map = phonemize(transcript, language='en-us', backend='espeak', strip=True)

    # Compare each word's spoken phonemes against the expected phonemes
    for i, word in enumerate(words):
        w_text = word.word
        start, end = word.start, word.end
        if not w_text or start is None or end is None:
            resolved_output.append("[missing]")
            continue

        # Slice the waveform for this word using its timestamps
        start_sample = int(start * sample_rate)
        end_sample = int(end * sample_rate)
        segment = waveform[:, start_sample:end_sample]

        # Run through the phoneme model (replace with the appropriate API if using a different model)
        with torch.no_grad():
            emissions, _ = phon_proc.extract_features(segment.squeeze(), padding_mask=None)
        phoneme_str = "".join(phon_proc.decode(emissions.argmax(dim=-1).tolist()))

        # Compare to the expected phoneme string for this word
        expected = phoneme_text_map[i] if i < len(phoneme_text_map) else ""
        sim = SequenceMatcher(
            None,
            normalize_phoneme_string(phoneme_str),
            normalize_phoneme_string(expected),
        ).ratio()
        print(f"Word: {w_text} | spoken: {phoneme_str} | expected: {expected} | sim: {sim:.2f}")

        if sim >= 0.35:
            resolved_output.append(w_text)
        else:
            resolved_output.append(phoneme_str)

    return {
        "transcript": full_text,
        "phonemes": "[per word, internal]",
        "resolved": " ".join(resolved_output),
    }


@app.get("/")
def root():
    return {"message": "Backend is running"}
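

# Optional local entry point: a minimal sketch, assuming uvicorn is installed and
# this file is importable as "main". The module name and port are assumptions,
# not part of the original deployment; adjust them to match your setup.
if __name__ == "__main__":
    import uvicorn

    # Start the FastAPI app directly with `python main.py`
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the running server (hypothetical file name sample.webm):
#   curl -F "audio=@sample.webm" http://localhost:8000/api/transcribe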