import torch
import torchaudio
import gradio as gr
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder

# Load hyperparameters from the HyperPyYAML config
hparams_file = "train.yaml"
with open(hparams_file, "r") as fin:
    hparams = load_hyperpyyaml(fin)
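# Note: this script assumes train.yaml provides at least the keys used
# below: encoder_file, ngram_lm_path, modules, checkpointer, and
# log_softmax.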

# Initialize the label encoder
encoder = sb.dataio.encoder.CTCTextEncoder()

encoder.load_or_create(
    path=hparams["encoder_file"],
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
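# special_labels pins the CTC blank to index 0 and the unknown token to
# index 1, which lines up with the decoder label list built below.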

# Prepare labels for the CTC decoder: pyctcdecode treats the empty string
# at index 0 as the blank token; the last label is replaced with "1",
# presumably to match the trained vocabulary.
ind2lab = encoder.ind2lab
labels = [ind2lab[x] for x in range(len(ind2lab))]
labels = [""] + labels[1:-1] + ["1"]

# Initialize the CTC decoder
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,
    beta=1.0,
)
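# In pyctcdecode, alpha is the language-model weight used during shallow
# fusion and beta is the per-word insertion bonus.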


# Define the ASR class with the `treat_wav` method
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # Run the wav2vec2 front end, the encoder, and the CTC head to get
        # per-frame log-probabilities.
        feats = self.modules.wav2vec2(sig.to(self.device), torch.tensor([1]).to(self.device))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each item in the batch (a single utterance here).
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])


# Initialize the ASR model
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts={"device": "cpu"},
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = encoder
# Restore the latest checkpoint, if one exists, and switch to eval mode.
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
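# Quick sanity check (hypothetical, commented out): one second of silence
# at 16 kHz should run through the pipeline and return a string.
#
#   dummy = torch.zeros(1, 16000)
#   print(asr_brain.treat_wav(dummy))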


# Function to process audio files
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Read the audio, downmix multi-channel input to mono, and resample to
    # the 16 kHz rate the wav2vec2 front end expects.
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)

    # Transcribe the audio
    sentence = asr.treat_wav(resampled)
    return sentence
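# Example (hypothetical file path): treat_wav_file(None, "sample.wav")
# reads the file, resamples it to 16 kHz, and returns the transcript.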


# Gradio interface
title = "Tunisian Speech Recognition"
description = (
    "This is a Tunisian ASR system based on the WavLM model, fine-tuned on a "
    "2.5-hour dataset, reaching a WER of 24% and a CER of 9%.\n\n"
    "Interesting, isn't it?"
)

gr.Interface(
    fn=treat_wav_file,
    inputs=[
        gr.Audio(sources=["microphone"], type="filepath", label="Record"),
        gr.Audio(sources=["upload"], type="filepath", label="Upload File"),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()