fonetik-fast / app.py
greg0rs's picture
Update app.py
5ce9e9f verified
raw
history blame
4.15 kB
import os
import io
import subprocess
import re
from difflib import SequenceMatcher
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import nltk
from nltk.corpus import cmudict
# NLTK config
nltk.data.path.append("/app/cache/nltk_data")
cmu = cmudict.dict()

# CMUdict marks vowel stress with a trailing 0/1/2 on the phone (e.g. "AE1").
_STRESS_DIGITS = re.compile(r"[012]")


def _build_phoneme_index(pronunciations):
    """Invert a word -> list-of-pronunciations mapping into "PH ON ES" -> [words].

    Each pronunciation (a list of phone strings) is indexed twice:

    * exactly as listed, phones joined by single spaces (e.g. "K AE1 T"); these
      entries are built first, in the same order as a plain single pass, and
    * with its stress digits removed (e.g. "K AE T"), so phone strings that carry
      no stress information can still be matched. Stress-free entries are added
      only after all exact entries, so an exact key keeps its original first word.
    """
    index = {}
    for word, phoneme_lists in pronunciations.items():
        for phonemes in phoneme_lists:
            index.setdefault(" ".join(phonemes), []).append(word)

    for word, phoneme_lists in pronunciations.items():
        for phonemes in phoneme_lists:
            exact = " ".join(phonemes)
            bare = _STRESS_DIGITS.sub("", exact)
            if bare == exact:
                continue  # nothing to strip; already indexed in the first pass
            words = index.setdefault(bare, [])
            # Pronunciation variants that differ only in stress collapse to the
            # same bare key; list each word once.
            if word not in words:
                words.append(word)
    return index


# Build reverse phoneme-to-word dictionary.
# NOTE(review): the stress-free keys assume the phoneme model emits phones
# without stress marks (clean_phoneme_string only handles letters) — confirm.
phoneme_to_word = _build_phoneme_index(cmu)
# Zero-width position sitting between two consecutive ASCII capital letters.
_ADJACENT_CAPITALS = re.compile(r"(?<=[A-Z])(?=[A-Z])")


def clean_phoneme_string(raw: str) -> str:
    """Space out runs of capital letters and trim surrounding whitespace.

    A single space is inserted between every pair of adjacent ASCII uppercase
    letters ("KAE" -> "K A E"); lowercase and non-ASCII characters are left as
    they are. Leading and trailing whitespace is then stripped.
    """
    spaced = _ADJACENT_CAPITALS.sub(" ", raw)
    return spaced.strip()
def guess_word_from_phonemes(phoneme_string: str) -> str:
    """Return a dictionary word whose pronunciation matches *phoneme_string*.

    The input is upper-cased and its whitespace normalised to single spaces so it
    has the same shape as the keys of ``phoneme_to_word``.

    Returns:
        The first word stored under the normalised key, or ``"[unknown]"`` when
        no pronunciation matches.
    """
    # Keys in phoneme_to_word are phones joined with exactly one space, so any
    # leading/trailing or repeated whitespace in the input would guarantee a miss.
    key = " ".join(phoneme_string.upper().split())
    return phoneme_to_word.get(key, ["[unknown]"])[0]
def words_are_close(word1: str, word2: str, threshold: float = 0.8) -> bool:
    """Case-insensitively test whether two strings are at least *threshold* similar.

    Similarity is ``difflib.SequenceMatcher.ratio()`` of the lower-cased
    strings, a value in [0, 1] where 1.0 means identical.
    """
    left, right = word1.lower(), word2.lower()
    similarity = SequenceMatcher(None, left, right).ratio()
    return similarity >= threshold
# Environment paths
# Point the Hugging Face and torch download caches at /app/cache.
# NOTE(review): huggingface_hub derives its cache paths from HF_HOME when it is
# first imported, and the transformers import at the top of this file has
# already triggered that, so this assignment probably does not relocate the hub
# cache. Set HF_HOME in the container environment (or before the imports) to be sure.
os.environ.update({"HF_HOME": "/app/cache", "TORCH_HOME": "/app/cache"})

# CORS: only the frontend origin http://localhost:8080 may call this API,
# with credentials allowed and any method/header.
_cors_options = {
    "allow_origins": ["http://localhost:8080"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
# Load models
# Hugging Face model ids: a phoneme-level CTC model and an English speech-to-text CTC model.
_PHONEME_MODEL_ID = "vitouphy/wav2vec2-xls-r-300m-timit-phoneme"
_STT_MODEL_ID = "facebook/wav2vec2-base-960h"

# huggingface_hub turns HF_HOME into its default cache path when it is first
# imported, which happens during the transformers import at the top of the file,
# before HF_HOME is assigned. Pass the cache directory explicitly so downloads
# land under $HF_HOME/hub as intended (an explicit HF_HUB_CACHE takes precedence,
# mirroring the library's own lookup order).
_HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE") or os.path.join(
    os.environ.get("HF_HOME", "/app/cache"), "hub"
)

try:
    phoneme_processor = Wav2Vec2Processor.from_pretrained(_PHONEME_MODEL_ID, cache_dir=_HF_HUB_CACHE)
    phoneme_model = Wav2Vec2ForCTC.from_pretrained(_PHONEME_MODEL_ID, cache_dir=_HF_HUB_CACHE)
    stt_processor = Wav2Vec2Processor.from_pretrained(_STT_MODEL_ID, cache_dir=_HF_HUB_CACHE)
    stt_model = Wav2Vec2ForCTC.from_pretrained(_STT_MODEL_ID, cache_dir=_HF_HUB_CACHE)
    print("✅ Models loaded successfully.")
except Exception as e:
    # Log and re-raise: the app cannot serve requests without the models.
    print("❌ Model load error:", str(e))
    raise
def convert_webm_to_wav(webm_bytes: bytes) -> io.BytesIO:
    """Transcode in-memory encoded audio (e.g. WebM) to WAV with the ffmpeg CLI.

    No input format is forced, so ffmpeg detects the container from the bytes.

    Args:
        webm_bytes: Encoded audio, fed to ffmpeg on stdin.

    Returns:
        A ``BytesIO`` positioned at 0 holding ffmpeg's WAV output.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status (its stderr is printed first).
        FileNotFoundError: the ``ffmpeg`` executable is not on PATH (from ``subprocess.run``).
    """
    process = subprocess.run(
        # -hide_banner / -loglevel error keep stderr down to actual error messages.
        ["ffmpeg", "-hide_banner", "-loglevel", "error",
         "-i", "pipe:0", "-f", "wav", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.returncode != 0:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8; a strict decode
        # could raise UnicodeDecodeError and hide the real conversion failure.
        print("❌ ffmpeg error:", process.stderr.decode("utf-8", errors="replace"))
        raise RuntimeError("ffmpeg conversion failed")
    return io.BytesIO(process.stdout)
# Sample rate both wav2vec2 models are fed at.
_TARGET_SAMPLE_RATE = 16000


def _load_mono_waveform(wav_io: io.BytesIO) -> torch.Tensor:
    """Decode WAV data into a 1-D float waveform resampled to 16 kHz."""
    waveform, sample_rate = torchaudio.load(wav_io)
    # torchaudio.load returns (channels, frames). Averaging over channels yields
    # one mono track: for mono input this equals the old squeeze() result exactly,
    # while multi-channel input no longer reaches the processor as a 2-D array.
    waveform = waveform.mean(dim=0)
    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=_TARGET_SAMPLE_RATE
        )(waveform)
    return waveform


def _ctc_greedy_decode(
    processor: Wav2Vec2Processor, model: Wav2Vec2ForCTC, waveform: torch.Tensor
) -> str:
    """Run one CTC model over a 16 kHz mono *waveform* and decode its argmax path."""
    input_values = processor(
        waveform, sampling_rate=_TARGET_SAMPLE_RATE, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = model(input_values).logits
    # Highest-scoring token per frame; decode the first (only) batch item.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded recording with a phoneme model and a speech-to-text model.

    Returns a dict with:
      * "phonemes": cleaned output of the phoneme model,
      * "phoneme_guess": dictionary word looked up from those phonemes
        ("[unknown]" when none matches),
      * "transcript": output of the speech-to-text model,
      * "resolved": the transcript when words_are_close(transcript, phoneme_guess),
        otherwise the phonemes.
    On any exception the error is logged and every field is "[Error]"; this is
    returned as a normal response, not as an HTTP error status.
    """
    import traceback

    # NOTE(review): ffmpeg and both model forward passes run synchronously inside
    # this async handler, so they block the event loop for the whole request;
    # consider a plain `def` endpoint or a thread offload if concurrency matters.
    try:
        contents = await audio.read()
        waveform = _load_mono_waveform(convert_webm_to_wav(contents))

        phonemes = clean_phoneme_string(
            _ctc_greedy_decode(phoneme_processor, phoneme_model, waveform)
        )
        phoneme_guess = guess_word_from_phonemes(phonemes)

        transcript = _ctc_greedy_decode(stt_processor, stt_model, waveform)

        # Trust the STT transcript only when it agrees with the word recovered
        # from the phonemes; otherwise surface the phonemes themselves.
        resolved = transcript if words_are_close(transcript, phoneme_guess) else phonemes

        return {
            "phonemes": phonemes,
            "phoneme_guess": phoneme_guess,
            "transcript": transcript,
            "resolved": resolved
        }
    except Exception as e:
        print("❌ Transcription error:", str(e))
        traceback.print_exc()
        return {
            "phonemes": "[Error]",
            "phoneme_guess": "[Error]",
            "transcript": "[Error]",
            "resolved": "[Error]"
        }
@app.get("/")
def root():
    """Return a fixed status payload confirming the server is up."""
    status = {"message": "Backend is running"}
    return status