Spaces:
Running
Running
import os | |
import io | |
import subprocess | |
import re | |
from difflib import SequenceMatcher | |
from fastapi import FastAPI, UploadFile, File | |
from fastapi.middleware.cors import CORSMiddleware | |
import torchaudio | |
import torch | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
import nltk | |
from nltk.corpus import cmudict | |
# NLTK config: add the app's cache directory to NLTK's data search path.
nltk.data.path.append("/app/cache/nltk_data")
# Pronunciation mapping; iterated below as word -> list of phoneme lists.
cmu = cmudict.dict()

# Build reverse phoneme-to-word dictionary.
# Each key is one pronunciation's phoneme symbols joined by single spaces;
# each value lists every word that has that pronunciation, in iteration order.
phoneme_to_word = {}
for word, phoneme_lists in cmu.items():
    for pronunciation in phoneme_lists:
        spelled_out = " ".join(pronunciation)
        if spelled_out in phoneme_to_word:
            phoneme_to_word[spelled_out].append(word)
        else:
            phoneme_to_word[spelled_out] = [word]
def clean_phoneme_string(raw: str) -> str:
    """Insert a space between adjacent uppercase ASCII letters and trim the ends.

    The pattern is zero-width: it matches every position that has an
    uppercase letter on both sides, so ``"AB"`` becomes ``"A B"``. Characters
    other than ``A``-``Z`` are left untouched. Leading and trailing
    whitespace is stripped from the result.
    """
    spaced = re.sub(r"(?<=[A-Z])(?=[A-Z])", " ", raw)
    return spaced.strip()
def guess_word_from_phonemes(phoneme_string: str) -> str:
    """Look up a word for a space-separated phoneme string.

    The string is upper-cased and used as a key into ``phoneme_to_word``.
    The first word recorded for that key is returned, or ``"[unknown]"``
    when the key is absent.

    NOTE(review): the lookup is an exact string match. CMUdict pronunciations
    normally carry stress digits (e.g. ``AH0``), so a stress-free phoneme
    string may never match -- confirm the expected input format.
    """
    candidates = phoneme_to_word.get(phoneme_string.upper())
    if candidates is None:
        return "[unknown]"
    return candidates[0]
def words_are_close(word1: str, word2: str, threshold: float = 0.8) -> bool:
    """Return True when two words are similar enough, ignoring case.

    Similarity is ``difflib.SequenceMatcher.ratio()`` computed on the
    lower-cased inputs; the words count as close when that ratio is at
    least ``threshold``.
    """
    matcher = SequenceMatcher(None, word1.lower(), word2.lower())
    similarity = matcher.ratio()
    return similarity >= threshold
# Environment paths: point the Hugging Face and Torch cache variables at
# the app's cache directory.
# NOTE(review): these are assigned after torch/transformers were imported at
# the top of the file -- confirm the libraries still honor them at this point.
os.environ.update(HF_HOME='/app/cache', TORCH_HOME='/app/cache')

app = FastAPI()

# CORS: only the local development origin is allowed; every method and
# header is accepted, and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:8080"],
)
# Load models: a phoneme-recognition checkpoint and a speech-to-text
# checkpoint, each with its matching processor. Any failure is reported and
# re-raised so the module does not finish importing without its models.
_PHONEME_CHECKPOINT = "vitouphy/wav2vec2-xls-r-300m-timit-phoneme"
_STT_CHECKPOINT = "facebook/wav2vec2-base-960h"
try:
    phoneme_processor = Wav2Vec2Processor.from_pretrained(_PHONEME_CHECKPOINT)
    phoneme_model = Wav2Vec2ForCTC.from_pretrained(_PHONEME_CHECKPOINT)
    stt_processor = Wav2Vec2Processor.from_pretrained(_STT_CHECKPOINT)
    stt_model = Wav2Vec2ForCTC.from_pretrained(_STT_CHECKPOINT)
    print("β Models loaded successfully.")
except Exception as load_error:
    print("β Model load error:", str(load_error))
    raise
def convert_webm_to_wav(webm_bytes: bytes) -> io.BytesIO:
    """Transcode in-memory audio to WAV by piping it through ``ffmpeg``.

    The bytes are fed to ffmpeg on stdin (``pipe:0``) and the WAV output is
    collected from stdout (``pipe:1``); nothing is written to disk here.

    Args:
        webm_bytes: Encoded audio to convert.

    Returns:
        A ``BytesIO`` positioned at the start of the WAV data.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    process = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "pipe:1"],
        input=webm_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.returncode != 0:
        # ffmpeg's diagnostics are not guaranteed to be valid UTF-8; decode
        # leniently so a stray byte cannot raise UnicodeDecodeError here and
        # hide the conversion failure reported below.
        print("β ffmpeg error:", process.stderr.decode(errors="replace"))
        raise RuntimeError("ffmpeg conversion failed")
    return io.BytesIO(process.stdout)
def _greedy_ctc_decode(processor, model, samples, sample_rate):
    """Run one CTC model over ``samples`` and return its argmax decoding."""
    input_values = processor(
        samples, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded recording with both the phoneme and STT models.

    The upload is converted to WAV, resampled to 16 kHz if needed, downmixed
    to a single channel, and run through both models. The speech-to-text
    transcript is used as the resolved answer when it is close to the word
    guessed from the phonemes; otherwise the phoneme string is returned.

    NOTE(review): no route decorator is visible on this handler in this view
    -- confirm it is registered with ``app``.

    Args:
        audio: The uploaded audio file.

    Returns:
        A dict with the keys ``phonemes``, ``phoneme_guess``, ``transcript``
        and ``resolved``. If any step fails, every value is ``"[Error]"``.
    """
    try:
        contents = await audio.read()
        wav_io = convert_webm_to_wav(contents)
        waveform, sample_rate = torchaudio.load(wav_io)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # torchaudio.load returns (channels, frames) by default. Averaging
        # over the channel axis always yields a 1-D signal; squeeze() only
        # did so for mono input and left stereo recordings 2-D.
        samples = waveform.mean(dim=0)

        # Run phoneme model
        raw_phonemes = _greedy_ctc_decode(phoneme_processor, phoneme_model, samples, sample_rate)
        phonemes = clean_phoneme_string(raw_phonemes)
        phoneme_guess = guess_word_from_phonemes(phonemes)

        # Run speech-to-text model
        transcript = _greedy_ctc_decode(stt_processor, stt_model, samples, sample_rate)

        # Decide what to return: trust the transcript only when it agrees
        # with the dictionary word guessed from the phonemes.
        resolved = transcript if words_are_close(transcript, phoneme_guess) else phonemes

        return {
            "phonemes": phonemes,
            "phoneme_guess": phoneme_guess,
            "transcript": transcript,
            "resolved": resolved,
        }
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and return a
        # payload with the same keys instead of propagating the exception.
        print("β Transcription error:", str(e))
        return {
            "phonemes": "[Error]",
            "phoneme_guess": "[Error]",
            "transcript": "[Error]",
            "resolved": "[Error]",
        }
def root():
    """Return a fixed status payload indicating the backend is up.

    NOTE(review): no route decorator is visible on this function in this
    view -- confirm it is registered with ``app``.
    """
    status = {"message": "Backend is running"}
    return status