# TunASR / app.py
# Author: brdhaker3 — last commit 75a78b0 ("Update app.py", verified)
import torch
import gradio as gr
import speechbrain as sb
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder
import os
# Load the hyperparameters used to build the ASR model. Keys read from
# `hparams` in this file include "encoder_file", "ngram_lm_path", "modules",
# "checkpointer" and "log_softmax".
hparams_file = "train.yaml"
# An explicit encoding keeps YAML parsing independent of the host locale.
with open(hparams_file, "r", encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin)
# Initialize the character-level label encoder used for CTC.
# load_or_create reads the vocabulary from hparams["encoder_file"]; the
# datasets it could otherwise build a vocabulary from are empty here, so the
# file is presumably expected to exist already (confirm for fresh deployments).
encoder = sb.dataio.encoder.CTCTextEncoder()
encoder.load_or_create(
    path=hparams["encoder_file"],
    from_didatasets=[[]],
    output_key="char_list",
    # Index 0 is reserved for the CTC blank and index 1 for the unknown label.
    # The decoder label list relies on the blank sitting at index 0.
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
# Build the label list handed to pyctcdecode from the encoder's
# index -> label mapping, ordered by index.
ind2lab = encoder.ind2lab
labels = [ind2lab[idx] for idx in range(len(ind2lab))]
# Slot 0 (the CTC blank, per special_labels) becomes "", which pyctcdecode
# treats as the blank token.
# NOTE(review): the last label is replaced by the literal "1". The reason is
# not visible in this file; confirm it against the training recipe.
labels = ["", *labels[1:-1], "1"]
# Beam-search CTC decoder with KenLM n-gram shallow fusion. `alpha` weights
# the language-model score and `beta` is the word-insertion bonus.
decoder = build_ctcdecoder(
    labels=labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,  # LM weight
    beta=1.0,  # word insertion bonus
)
# Define the ASR class with the `treat_wav` method
class ASR(sb.core.Brain):
    """SpeechBrain ``Brain`` subclass exposing a single-utterance inference helper."""

    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text.

        Runs wav2vec2 -> encoder -> CTC linear layer -> log-softmax, then
        beam-search decodes the log-probabilities with the module-level
        pyctcdecode ``decoder``.

        Args:
            sig: waveform tensor. Presumably (batch, samples) at 16 kHz; the
                app's caller passes a single utterance with a leading batch
                dim of 1. TODO confirm for other callers.

        Returns:
            str: transcription of the first item in the batch.
        """
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            # Relative length 1 marks the whole signal as valid (unpadded).
            # Use the Brain's configured device rather than hard-coding "cpu".
            wav_lens = torch.tensor([1]).to(self.device)
            feats = self.modules.wav2vec2(sig.to(self.device), wav_lens)
            feats = self.modules.enc(feats)
            logits = self.modules.ctc_lin(feats)
            p_ctc = self.hparams.log_softmax(logits)

        # Only the first item's transcription is returned, so only it is decoded.
        return decoder.decode(p_ctc[0].detach().cpu().numpy())
# Build the inference Brain on CPU and attach the label encoder. Then restore
# the saved checkpoint and put every module into evaluation mode.
asr_brain = ASR(
    hparams=hparams,
    modules=hparams["modules"],
    checkpointer=hparams["checkpointer"],
    run_opts={"device": "cpu"},
)
asr_brain.tokenizer = encoder
# NOTE(review): if no checkpoint is found this presumably continues silently
# with untrained weights. Confirm the checkpoint directory ships with the app.
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
# Function to process audio files
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe the audio coming from the Gradio microphone or upload input.

    The microphone recording wins when both inputs are provided.

    Args:
        file_mic: file path of the microphone recording, or None.
        file_upload: file path of the uploaded audio, or None.
        asr: Brain whose ``treat_wav`` performs the transcription.
        device: device the waveform is moved to before resampling.

    Returns:
        The transcription string, or an error message when no audio was given.
    """
    wav = file_mic if file_mic is not None else file_upload
    if wav is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # The file's native sample rate is needed to resample to 16 kHz below.
    sr = torchaudio.info(wav).sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    # Down-mix to mono by averaging over dim 1. This presumably assumes
    # read_audio returns (samples, channels) for multi-channel files; confirm.
    if sig.dim() > 1:
        sig = sig.mean(dim=1)
    # Add a leading batch dimension before moving to the target device.
    batch = sig.unsqueeze(0).to(device)
    resampled = torchaudio.functional.resample(batch, sr, 16000)
    return asr.treat_wav(resampled)
# Gradio interface: two audio inputs (microphone recording, file upload),
# forwarded in that order to treat_wav_file, with a plain-text output.
title = "Tunisian Speech Recognition"
description = ''' This is a Tunisian ASR based on WavLM Model, fine-tuned on a dataset of 2.5 Hours resulting in a W.E.R of 24% and a C.E.R of 9 %.
\n
\n Interesting isn\'t it !'''
gr.Interface(
    fn=treat_wav_file,
    title=title,
    description=description,
    # Order matters: it maps onto (file_mic, file_upload).
    inputs=[
        gr.Audio(sources="microphone", type="filepath", label="Record"),
        gr.Audio(sources="upload", type="filepath", label="Upload File"),
    ],
    outputs="text",
).launch()