import os
import io
import subprocess
from difflib import SequenceMatcher

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from phonemizer import phonemize
from faster_whisper import WhisperModel
# Set cache paths
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load models.
# NOTE: the phoneme recognizer is a stand-in. This fairseq hub checkpoint is a
# general ASR model; replace it with the phoneme model you actually use, or
# adapt the extract_features()/decode() calls further down to its API
# (a commented transformers-based alternative follows below).
phon_proc = torch.hub.load('pytorch/fairseq', 'wav2vec2_vox_960h')
whisper_model = WhisperModel("small", compute_type="float32")
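
# Optional alternative (hedged sketch, not part of the original listing): load a
# phoneme-level wav2vec2 checkpoint via the transformers library instead of the
# fairseq placeholder above. The checkpoint name and calls below are standard
# transformers usage, not something this Space is known to use.
#
# from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# hf_proc = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
# hf_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
#
# def phonemes_from_waveform(wave_1d, sr=16000):
#     """Return a phoneme string for a mono 1-D float waveform at 16 kHz."""
#     inputs = hf_proc(wave_1d.numpy(), sampling_rate=sr, return_tensors="pt")
#     with torch.no_grad():
#         logits = hf_model(**inputs).logits
#     ids = torch.argmax(logits, dim=-1)
#     return hf_proc.batch_decode(ids)[0]
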
def convert_webm_to_wav(bts):
    """Decode the uploaded audio (e.g. browser WebM) to 16 kHz mono WAV via ffmpeg."""
    p = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000", "-ac", "1", "pipe:1"],
        input=bts, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    )
    if p.returncode != 0:
        raise RuntimeError(p.stderr.decode())
    return io.BytesIO(p.stdout)

def normalize_phoneme_string(s):
    # Keep only alphabetic characters so punctuation and separators don't skew the comparison
    return ''.join(c for c in s if c.isalpha())

def words_sim(a, b, threshold=0.35):
    # Loose case-insensitive "same word" test via difflib's similarity ratio
    return SequenceMatcher(None, a.lower(), b.lower()).ratio() >= threshold
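
# Illustrative sanity checks for the helpers above (example values are additions,
# not from the original code):
#   normalize_phoneme_string("h@'loU")  -> "hloU"  (non-alphabetic characters dropped)
#   words_sim("hello", "hallo")         -> True    (SequenceMatcher ratio 0.8 >= 0.35)
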
@app.post("/transcribe")  # decorator restored; the route path is assumed from the function name
async def transcribe(audio: UploadFile = File(...)):
    data = await audio.read()
    wav_io = convert_webm_to_wav(data)
    waveform, sample_rate = torchaudio.load(wav_io)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000
    # Whisper transcription with word-level timestamps
    wav_io.seek(0)  # torchaudio.load() consumed the buffer; rewind before passing it to Whisper
    segments, _ = whisper_model.transcribe(wav_io, word_timestamps=True)
    words = []
    transcript = []
    for seg in segments:
        for w in seg.words:
            words.append(w)
            transcript.append(w.word.strip())  # faster-whisper word texts carry a leading space
    full_text = " ".join(transcript)
    # Phonemize the expected words once, then compare each spoken word against them
    resolved_output = []
    phoneme_text_map = phonemize(transcript, language='en-us', backend='espeak', strip=True)
    for i, word in enumerate(words):
        w_text = word.word.strip()
        start, end = word.start, word.end
        if not w_text or start is None or end is None:
            resolved_output.append("[missing]")
            continue
        # Slice the waveform to this word using Whisper's timestamps
        start_sample = int(start * sample_rate)
        end_sample = int(end * sample_rate)
        segment = waveform[:, start_sample:end_sample]
        # Run the slice through the phoneme model. This block assumes a model exposing
        # extract_features()/decode(); adapt it to whichever recognizer is loaded above.
        with torch.no_grad():
            emissions, _ = phon_proc.extract_features(segment.squeeze(), padding_mask=None)
        phoneme_str = "".join(phon_proc.decode(emissions.argmax(dim=-1).tolist()))
        # Compare to expected pronunciation
        expected = phoneme_text_map[i] if i < len(phoneme_text_map) else ""
        sim = SequenceMatcher(None, normalize_phoneme_string(phoneme_str), normalize_phoneme_string(expected)).ratio()
        print(f"Word: {w_text} | spoken: {phoneme_str} | expected: {expected} | sim: {sim:.2f}")
        if sim >= 0.35:
            resolved_output.append(w_text)
        else:
            resolved_output.append(phoneme_str)
    return {
        "transcript": full_text,
        "phonemes": "[per word, internal]",
        "resolved": " ".join(resolved_output),
    }

@app.get("/")  # health-check route; decorator restored
def root():
    return {"message": "Backend is running"}