import os

import torch
import torchaudio
import gradio as gr
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder

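# Load the training configuration (modules, paths, decoding settings)
# from the HyperPyYAML file used at training time.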
hparams_file = "train.yaml"
with open(hparams_file, "r") as fin:
    hparams = load_hyperpyyaml(fin)

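# Restore the character-level CTC label encoder saved during training.
# `from_didatasets=[[]]` is empty because the vocabulary is loaded from
# file rather than rebuilt from data.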
encoder = sb.dataio.encoder.CTCTextEncoder()
encoder.load_or_create(
    path=hparams["encoder_file"],
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)

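# Build the label list for pyctcdecode: the CTC blank (index 0) becomes
# the empty string, and the final entry is replaced by the placeholder "1".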
ind2lab = encoder.ind2lab
labels = [ind2lab[x] for x in range(len(ind2lab))]
labels = [""] + labels[1:-1] + ["1"]

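# Beam-search CTC decoder rescored with a KenLM n-gram language model:
# alpha weights the LM score, beta is the word-insertion bonus.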
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,
    beta=1.0,
)

class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # wav2vec2 takes the signal and its relative length (1 = full signal).
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Run LM-rescored beam search on each utterance in the batch.
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])

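# Instantiate the Brain on CPU, attach the label encoder as tokenizer,
# restore the trained checkpoint, and put all modules in eval mode.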
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts={"device": "cpu"},
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()

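# Gradio callback: accept either a microphone recording or an uploaded
# audio file and return its transcription.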
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You must either record with the microphone or upload an audio file."

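    # Read the audio, downmix multi-channel signals to mono, and resample
    # to the 16 kHz rate expected by the wav2vec2 front-end.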
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)

    return asr.treat_wav(resampled)

title = "Tunisian Speech Recognition" |
|
description = ''' This is a Tunisian ASR based on WavLM Model, fine-tuned on a dataset of 2.5 Hours resulting in a W.E.R of 24% and a C.E.R of 9 %. |
|
\n |
|
\n Interesting isn\'t it !''' |
|
|
|
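# Wire the transcription function into a two-input Gradio interface.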
gr.Interface(
    fn=treat_wav_file,
    inputs=[
        gr.Audio(sources="microphone", type="filepath", label="Record"),
        gr.Audio(sources="upload", type="filepath", label="Upload File"),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()