# fonetik-fast / app.py
import os

# Cache paths must be set before importing libraries that read them at import time
# (huggingface_hub resolves HF_HOME when it is first imported by faster_whisper).
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

import io
import subprocess
from difflib import SequenceMatcher

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from phonemizer import phonemize
from faster_whisper import WhisperModel
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
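# NOTE: CORS is wide open (any origin, method, header); convenient for a demo Space,
# but worth restricting allow_origins outside local testing.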
# Load models. The fairseq torch.hub call here was a placeholder ("assumed custom or replace
# as needed"); an espeak-phoneme wav2vec2 CTC checkpoint is substituted as one working option.
from transformers import AutoModelForCTC, AutoProcessor
phon_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
phon_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").eval()
whisper_model = WhisperModel("small", compute_type="float32")
def convert_webm_to_wav(bts):
    """Decode browser audio (webm/opus) to 16 kHz mono WAV via ffmpeg, returned as a BytesIO."""
    p = subprocess.run(["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000", "-ac", "1", "pipe:1"],
                       input=bts, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if p.returncode != 0:
        raise RuntimeError(p.stderr.decode())
    return io.BytesIO(p.stdout)
def normalize_phoneme_string(s):
    # Keep only alphabetic characters (str.isalpha), dropping spaces, digits and punctuation.
    return ''.join(c for c in s if c.isalpha())
def words_sim(a, b, threshold=0.35):
return SequenceMatcher(None, a.lower(), b.lower()).ratio() >= threshold
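# Generic word-similarity helper (the endpoint below compares phoneme strings directly).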
@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    data = await audio.read()
    wav_io = convert_webm_to_wav(data)
    waveform, sample_rate = torchaudio.load(wav_io)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000
    # Whisper transcription with word-level timestamps.
    # torchaudio.load consumed the buffer above, so rewind it before handing it to Whisper.
    wav_io.seek(0)
    segments, _ = whisper_model.transcribe(wav_io, word_timestamps=True)
    words = []
    transcript = []
    for seg in segments:
        for w in seg.words:
            words.append(w)
            transcript.append(w.word.strip())  # faster-whisper word text usually carries a leading space
    full_text = " ".join(transcript)
# Compare each word
resolved_output = []
    # espeak phonemization of each expected word, index-aligned with `words`
    phoneme_text_map = phonemize(transcript, language='en-us', backend='espeak', strip=True)
    for i, word in enumerate(words):
        w_text = word.word.strip()
        start, end = word.start, word.end
        if not w_text or start is None or end is None:
            resolved_output.append("[missing]")
            continue
        # Slice this word's audio out of the 16 kHz waveform using Whisper's timestamps.
        start_sample = int(start * sample_rate)
        end_sample = int(end * sample_rate)
        segment = waveform[:, start_sample:end_sample]
        # Run the word audio through the phoneme model (adjust this block if you swap models).
        with torch.no_grad():
            inputs = phon_processor(segment.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt")
            logits = phon_model(inputs.input_values).logits
        phoneme_str = phon_processor.batch_decode(logits.argmax(dim=-1))[0]
# Compare to expected
expected = phoneme_text_map[i] if i < len(phoneme_text_map) else ""
sim = SequenceMatcher(None, normalize_phoneme_string(phoneme_str), normalize_phoneme_string(expected)).ratio()
print(f"Word: {w_text} | spoken: {phoneme_str} | expected: {expected} | sim: {sim:.2f}")
        if sim >= 0.35:
            resolved_output.append(w_text)
        else:
            # Low similarity: surface the spoken phonemes instead of the recognized word.
            resolved_output.append(phoneme_str)
return {
"transcript": full_text,
"phonemes": "[per word, internal]",
"resolved": " ".join(resolved_output)
}
@app.get("/")
def root():
return {"message": "Backend is running"}
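# Local usage sketch (assumed entry point and port; the Space's Dockerfile may differ):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "audio=@sample.webm" http://localhost:7860/api/transcribe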