import torch
import gradio as gr
import speechbrain as sb
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder
import os
# Load hyperparameters and initialize the ASR model
hparams_file = "train.yaml"
with open(hparams_file, "r") as fin:
    hparams = load_hyperpyyaml(fin)
# Initialize the label encoder
encoder = sb.dataio.encoder.CTCTextEncoder()
encoder.load_or_create(
    path=hparams["encoder_file"],
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
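# The blank/unk indices above are assumed to match the ones used at training
# time; if they differ, the label table fed to the decoder below will be wrong.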
# Prepare labels for the CTC decoder
ind2lab = encoder.ind2lab
labels = [ind2lab[x] for x in range(len(ind2lab))]
# Replace the CTC blank (index 0) with "" as pyctcdecode expects; the last
# label is swapped for a "1" placeholder.
labels = [""] + labels[1:-1] + ["1"]
# Initialize the CTC decoder
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,
    beta=1.0,
)
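# alpha scales the n-gram LM score and beta is a word-insertion bonus; both
# only take effect when kenlm_model_path is set. As a sketch, a plain
# beam-search decoder without the language model would be:
#   decoder = build_ctcdecoder(labels)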
# Define the ASR class with the `treat_wav` method
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # Forward pass: wav2vec2 takes the signal plus relative lengths
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each utterance in the batch
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])
# Initialize the ASR model
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts={"device": "cpu"},
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
# Function to process audio files
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    """Transcribe a recorded or uploaded audio file."""
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: you must either record with the microphone or upload an audio file"
    # Read the audio, downmix to mono, and resample to the model's 16 kHz rate
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
    # Transcribe the audio
    return asr.treat_wav(resampled)
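# Example usage outside the UI (hypothetical file name):
#   print(treat_wav_file(None, "sample.wav"))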
# Gradio interface
title = "Tunisian Speech Recognition"
description = '''This is a Tunisian ASR system based on the WavLM model, fine-tuned on a 2.5-hour dataset, reaching a WER of 24% and a CER of 9%.

Interesting, isn't it!'''
gr.Interface(
    fn=treat_wav_file,
    inputs=[
        gr.Audio(sources="microphone", type="filepath", label="Record"),
        gr.Audio(sources="upload", type="filepath", label="Upload File"),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()
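# To expose a temporary public URL when running locally, .launch(share=True)
# can be used instead.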