File size: 3,383 Bytes
2db3ee9
 
 
5ce9e9f
2db3ee9
 
 
 
 
d197b93
5ad53ef
88dc312
d197b93
2db3ee9
e6ffe57
2dea768
 
d197b93
2db3ee9
4adedcb
5ad53ef
 
c5a833f
d197b93
 
c5a833f
d181646
d197b93
 
 
 
c5a833f
4e58e35
 
 
c5a833f
0671f09
d181646
2db3ee9
c5a833f
2db3ee9
 
d197b93
5ad53ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d181646
5ad53ef
d181646
 
5ad53ef
 
 
e6ffe57
 
4e58e35
5ad53ef
 
 
d181646
5ad53ef
 
 
 
544543c
5ad53ef
 
 
4e58e35
5ad53ef
4e58e35
5ad53ef
 
d78328b
5ad53ef
4e58e35
d181646
5ad53ef
 
 
d181646
d197b93
c5a833f
d181646
 
 
4adedcb
c5a833f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import io
import subprocess
from difflib import SequenceMatcher

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from phonemizer import phonemize
from faster_whisper import WhisperModel

# Set cache paths
# NOTE(review): torch and faster_whisper are imported above, before these
# variables are assigned. Anything that reads them at import time
# (huggingface_hub is believed to resolve HF_HOME when imported) will not see
# these values — confirm, and consider setting them in the container
# environment or before the imports instead.
os.environ['HF_HOME'] = '/app/cache'
os.environ['TORCH_HOME'] = '/app/cache'

# FastAPI application with fully open CORS: any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests — confirm this is intended for deployment.
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

# Load models once at import time; both objects are shared by every request.
# NOTE(review): the phoneme model is a placeholder — confirm that
# 'wav2vec2_vox_960h' is an entry point of pytorch/fairseq's hubconf and that the
# returned object provides the extract_features()/decode() calls used in transcribe().
phon_proc = torch.hub.load('pytorch/fairseq', 'wav2vec2_vox_960h')  # Phoneme model is assumed custom or replace as needed
# faster-whisper "small" checkpoint, computing in float32.
whisper_model = WhisperModel("small", compute_type="float32")


def convert_webm_to_wav(bts, sample_rate=16000, channels=1):
    """Transcode encoded audio bytes to an in-memory WAV file using ffmpeg.

    Despite the name, the input is not restricted to WebM: no input format is
    forced, so ffmpeg probes whatever arrives on stdin.

    Args:
        bts: Raw encoded audio bytes.
        sample_rate: Output sample rate in Hz (default 16000).
        channels: Output channel count (default 1, i.e. mono).

    Returns:
        An ``io.BytesIO`` positioned at offset 0 holding the WAV data.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message is
            ffmpeg's stderr output.
        FileNotFoundError: the ``ffmpeg`` executable could not be found.
    """
    cmd = [
        "ffmpeg", "-i", "pipe:0",
        "-f", "wav",
        "-ar", str(sample_rate),
        "-ac", str(channels),
        "pipe:1",
    ]
    proc = subprocess.run(cmd, input=bts, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8; a strict decode
        # could raise UnicodeDecodeError and hide the actual ffmpeg failure.
        raise RuntimeError(proc.stderr.decode("utf-8", errors="replace"))
    return io.BytesIO(proc.stdout)


def normalize_phoneme_string(s):
    """Strip every character of *s* for which ``str.isalpha`` is false.

    Digits, whitespace and punctuation are removed; letters from any script,
    including IPA symbols that Unicode classifies as letters, are kept in order.
    """
    return "".join(filter(str.isalpha, s))


def words_sim(a, b, threshold=0.35):
    """Report whether two words are alike, ignoring case.

    The score is ``difflib.SequenceMatcher.ratio()`` of the lower-cased
    inputs; the words count as similar when the score reaches *threshold*.
    """
    matcher = SequenceMatcher(None, a.lower(), b.lower())
    score = matcher.ratio()
    return threshold <= score


def _decode_audio(data):
    """Decode uploaded audio bytes into a waveform tensor at 16 kHz.

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` is the
        ``(channels, samples)`` tensor produced by ``torchaudio.load`` and
        ``sample_rate`` is always 16000.
    """
    waveform, sample_rate = torchaudio.load(convert_webm_to_wav(data))
    if sample_rate != 16000:
        # ffmpeg is already asked for 16 kHz; this only guards against a mismatch.
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        sample_rate = 16000
    return waveform, sample_rate


def _whisper_words(waveform):
    """Transcribe *waveform* with Whisper and return its word objects in order."""
    # Hand Whisper the decoded samples instead of the WAV BytesIO: torchaudio.load
    # has already read that stream to EOF, so re-reading it would yield no audio.
    # faster-whisper takes ndarray input as mono float32 samples at 16 kHz.
    audio = waveform.mean(dim=0).numpy()
    segments, _ = whisper_model.transcribe(audio, word_timestamps=True)
    # `segments` is lazy; iterating it is what actually runs the decoder.
    return [w for seg in segments for w in (seg.words or [])]


def _spoken_phonemes(segment):
    """Run the phoneme model on one word's audio slice and return the decoded string."""
    # NOTE(review): extract_features()/decode() are placeholders for whatever model
    # phon_proc really is; confirm the input shape it expects and its decode API.
    with torch.no_grad():
        emissions, _ = phon_proc.extract_features(segment.squeeze(), padding_mask=None)
    return "".join(phon_proc.decode(emissions.argmax(dim=-1).tolist()))


def _transcribe_bytes(data):
    """Blocking worker for transcribe(): decode, transcribe, and score each word."""
    waveform, sample_rate = _decode_audio(data)
    words = _whisper_words(waveform)

    # Whisper word tokens keep their leading space (" hello"); strip it so the
    # joined strings get single spaces and phonemize() sees clean words.
    texts = [(w.word or "").strip() for w in words]

    # phonemize() silently drops empty input lines, which would shift every later
    # result onto the wrong word; phonemize only non-empty words, keyed by index.
    indices = [i for i, t in enumerate(texts) if t]
    expected_by_index = {}
    if indices:
        phonemes = phonemize([texts[i] for i in indices], language='en-us', backend='espeak', strip=True)
        expected_by_index = dict(zip(indices, phonemes))

    resolved_output = []
    for i, (word, w_text) in enumerate(zip(words, texts)):
        start, end = word.start, word.end
        if not w_text or start is None or end is None:
            resolved_output.append("[missing]")
            continue

        segment = waveform[:, int(start * sample_rate):int(end * sample_rate)]
        if segment.shape[-1] == 0:
            # Zero-length or out-of-range timestamps leave no audio to analyse;
            # feeding an empty tensor to the model would fail the whole request.
            resolved_output.append("[missing]")
            continue

        phoneme_str = _spoken_phonemes(segment)
        expected = expected_by_index.get(i, "")
        sim = SequenceMatcher(None, normalize_phoneme_string(phoneme_str), normalize_phoneme_string(expected)).ratio()

        print(f"Word: {w_text} | spoken: {phoneme_str} | expected: {expected} | sim: {sim:.2f}")

        # Same 0.35 cut-off as words_sim(): keep Whisper's word when the spoken
        # phonemes are close enough to the expected ones, else report what was heard.
        resolved_output.append(w_text if sim >= 0.35 else phoneme_str)

    return {
        "transcript": " ".join(t for t in texts if t),
        "phonemes": "[per word, internal]",
        "resolved": " ".join(resolved_output)
    }


@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip and check each word's pronunciation.

    The upload is converted with ffmpeg and transcribed by Whisper with word
    timestamps. Each word's audio slice is run through the phoneme model and
    compared against espeak phonemes for the transcribed word.

    Returns:
        A dict with three keys:

        - ``transcript``: Whisper's words joined by single spaces.
        - ``phonemes``: a fixed placeholder string.
        - ``resolved``: one entry per word. It is the Whisper word when its
          phoneme similarity is >= 0.35, the decoded phoneme string otherwise,
          or ``"[missing]"`` when the word has no text, timestamps or audio.

    Raises:
        RuntimeError: propagated from ffmpeg when the upload cannot be decoded.
    """
    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    data = await audio.read()
    # Decoding and model inference are blocking; running them inline inside an
    # async endpoint would stall the event loop (and every other request).
    return await run_in_threadpool(_transcribe_bytes, data)


@app.get("/")
def root():
    """Report that the backend process is up and answering requests."""
    status = {"message": "Backend is running"}
    return status