import os
import io
import subprocess
from difflib import SequenceMatcher
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from phonemizer import phonemize
from faster_whisper import WhisperModel
# Set cache paths
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
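# Note: wildcard CORS (allow_origins=["*"]) keeps development simple; in production this
# would typically be narrowed to the frontend's origin.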
# Load models
phon_proc = torch.hub.load('pytorch/fairseq', 'wav2vec2_vox_960h')  # Placeholder phoneme model; replace with the custom phoneme recognizer as needed
whisper_model = WhisperModel("small", compute_type="float32")
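# Both models are loaded once at startup so each request only pays inference cost;
# WhisperModel("small") downloads its weights into the HF_HOME cache on first run.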
def convert_webm_to_wav(bts):
    p = subprocess.run(["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000", "-ac", "1", "pipe:1"],
                       input=bts, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if p.returncode != 0:
        raise RuntimeError(p.stderr.decode())
    return io.BytesIO(p.stdout)
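# convert_webm_to_wav shells out to ffmpeg, so the ffmpeg binary must be available on PATH
# (e.g. installed in the Space's image); the pipes keep the conversion entirely in memory.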
def normalize_phoneme_string(s):
    return ''.join(c for c in s if c.isalpha())
def words_sim(a, b, threshold=0.35):
    return SequenceMatcher(None, a.lower(), b.lower()).ratio() >= threshold
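# words_sim is a text-level analogue of the phoneme similarity check in /api/transcribe;
# it is not currently called below.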
@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    data = await audio.read()
    wav_io = convert_webm_to_wav(data)
    waveform, sample_rate = torchaudio.load(wav_io)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000
    # Whisper transcription with word timestamps; rewind the buffer first,
    # since torchaudio.load has already consumed it
    wav_io.seek(0)
    segments, _ = whisper_model.transcribe(wav_io, word_timestamps=True)
    words = []
    transcript = []
    for seg in segments:
        for w in seg.words:
            words.append(w)
            # faster-whisper word tokens usually carry a leading space; strip it
            transcript.append(w.word.strip())
    full_text = " ".join(transcript)
    # Compare each word against its expected phonemes
    resolved_output = []
    phoneme_text_map = phonemize(transcript, language='en-us', backend='espeak', strip=True)
    for i, word in enumerate(words):
        w_text = word.word.strip()
        start, end = word.start, word.end
        if not w_text or start is None or end is None:
            resolved_output.append("[missing]")
            continue
        start_sample = int(start * sample_rate)
        end_sample = int(end * sample_rate)
        segment = waveform[:, start_sample:end_sample]
        # Run the word's audio through the phoneme model (replace with the appropriate API if using a different model)
        with torch.no_grad():
            emissions, _ = phon_proc.extract_features(segment.squeeze(), padding_mask=None)
        phoneme_str = "".join(phon_proc.decode(emissions.argmax(dim=-1).tolist()))
        # Compare to the expected phonemes for this word
        expected = phoneme_text_map[i] if i < len(phoneme_text_map) else ""
        sim = SequenceMatcher(None, normalize_phoneme_string(phoneme_str), normalize_phoneme_string(expected)).ratio()
        print(f"Word: {w_text} | spoken: {phoneme_str} | expected: {expected} | sim: {sim:.2f}")
        if sim >= 0.35:
            resolved_output.append(w_text)
        else:
            resolved_output.append(phoneme_str)
    return {
        "transcript": full_text,
        "phonemes": "[per word, internal]",
        "resolved": " ".join(resolved_output)
    }
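# Example request (sketch; the file name and port are placeholders, Spaces typically listen on 7860):
#   curl -X POST -F "audio=@recording.webm" http://localhost:7860/api/transcribe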
@app.get("/")
def root():
return {"message": "Backend is running"}