saeedzou commited on
Commit
cef39c8
·
verified ·
1 Parent(s): 9de87bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -53
app.py CHANGED
@@ -1,53 +1,154 @@
1
- import gradio as gr
2
- import nemo.collections.asr as nemo_asr
3
- from pydub import AudioSegment
4
- import os
5
- from huggingface_hub import login
6
-
7
- # Fetch the token from an environment variable
8
- HF_TOKEN = os.getenv("HF_TOKEN")
9
- if not HF_TOKEN:
10
- raise ValueError("HF_TOKEN environment variable not set. Please provide a valid Hugging Face token.")
11
-
12
- # Authenticate with Hugging Face
13
- login(HF_TOKEN)
14
-
15
- # Load the private NeMo ASR model
16
- try:
17
- asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
18
- model_name="faimlab/stt_fa_fastconformer_hybrid_large_dataset_v30"
19
- )
20
- except Exception as e:
21
- raise RuntimeError(f"Failed to load model: {str(e)}")
22
-
23
- # Function to convert audio to 16kHz mono WAV
24
- def convert_to_wav(audio_path, output_path="temp.wav"):
25
- audio = AudioSegment.from_file(audio_path)
26
- audio = audio.set_channels(1).set_frame_rate(16000)
27
- audio.export(output_path, format="wav")
28
- return output_path
29
-
30
- # Transcription function
31
- def transcribe_audio(audio):
32
- if audio is None:
33
- return "Please upload an audio file."
34
-
35
- wav_path = convert_to_wav(audio)
36
- output = asr_model.transcribe([wav_path])
37
-
38
- if os.path.exists(wav_path):
39
- os.remove(wav_path)
40
-
41
- return output[0].text
42
-
43
- # Create Gradio interface
44
- interface = gr.Interface(
45
- fn=transcribe_audio,
46
- inputs=gr.Audio(type="filepath", label="Upload Audio File"),
47
- outputs=gr.Textbox(label="Transcription"),
48
- title="ASR Transcription with NeMo",
49
- description="Upload an audio file to transcribe it using a private NeMo ASR model."
50
- )
51
-
52
- # Launch the app
53
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import nemo.collections.asr as nemo_asr
3
+ from pydub import AudioSegment
4
+ import os
5
+ import yt_dlp as youtube_dl
6
+ from huggingface_hub import login
7
+ from hazm import Normalizer
8
+ import numpy as np
9
+ import re
10
+ import time
11
+
12
+ # Fetch the token from an environment variable
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
14
+ if not HF_TOKEN:
15
+ raise ValueError("HF_TOKEN environment variable not set. Please provide a valid Hugging Face token.")
16
+
17
+ # Authenticate with Hugging Face
18
+ login(HF_TOKEN)
19
+
20
+ # Load the private NeMo ASR model
21
+ try:
22
+ asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
23
+ model_name="faimlab/stt_fa_fastconformer_hybrid_large_dataset_v30"
24
+ )
25
+ except Exception as e:
26
+ raise RuntimeError(f"Failed to load model: {str(e)}")
27
+
28
+ normalizer = Normalizer()
29
+
30
+ def load_audio(audio_path):
31
+ audio = AudioSegment.from_file(audio_path)
32
+ audio = audio.set_channels(1).set_frame_rate(16000)
33
+ audio_samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
34
+ audio_samples /= np.max(np.abs(audio_samples))
35
+ return audio_samples, audio.frame_rate
36
+
37
+ def transcribe_chunk(audio_chunk, model):
38
+ transcription = model.transcribe([audio_chunk], batch_size=1, verbose=False)
39
+ return transcription[0].text
40
+
41
+ def transcribe_audio(file_path, model, chunk_size=30*16000):
42
+ waveform, _ = load_audio(file_path)
43
+ transcriptions = []
44
+ for start in range(0, len(waveform), chunk_size):
45
+ end = min(len(waveform), start + chunk_size)
46
+ transcription = transcribe_chunk(waveform[start:end], model)
47
+ transcriptions.append(transcription)
48
+
49
+ transcriptions = ' '.join(transcriptions)
50
+ transcriptions = re.sub(' +', ' ', transcriptions)
51
+ transcriptions = normalizer.normalize(transcriptions)
52
+
53
+ return transcriptions
54
+
55
+ # YouTube audio download function
56
+ YT_LENGTH_LIMIT_S = 3600
57
+
58
+ def download_yt_audio(yt_url, filename, cookie_file="cookies.txt"):
59
+ info_loader = youtube_dl.YoutubeDL()
60
+
61
+ try:
62
+ info = info_loader.extract_info(yt_url, download=False)
63
+ except youtube_dl.utils.DownloadError as err:
64
+ raise gr.Error(str(err))
65
+
66
+ file_length = info["duration_string"]
67
+ file_h_m_s = file_length.split(":")
68
+ file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
69
+
70
+ if len(file_h_m_s) == 1:
71
+ file_h_m_s.insert(0, 0)
72
+ if len(file_h_m_s) == 2:
73
+ file_h_m_s.insert(0, 0)
74
+ file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
75
+
76
+ if file_length_s > YT_LENGTH_LIMIT_S:
77
+ yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
78
+ file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
79
+ raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
80
+
81
+ ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best", "cookies": cookie_file}
82
+
83
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
84
+ try:
85
+ ydl.download([yt_url])
86
+ except youtube_dl.utils.ExtractorError as err:
87
+ raise gr.Error(str(err))
88
+
89
+
90
+ # Gradio Interface
91
+ def transcribe(audio):
92
+ if audio is None:
93
+ return "Please upload an audio file."
94
+
95
+ transcription = transcribe_audio(audio, asr_model)
96
+
97
+ return transcription
98
+
99
+ def transcribe_yt(yt_url):
100
+ temp_filename = "/tmp/yt_audio.mp4" # Temporary filename for the downloaded video
101
+ download_yt_audio(yt_url, temp_filename)
102
+ transcription = transcribe_audio(temp_filename, asr_model)
103
+ return transcription
104
+
105
+ mf_transcribe = gr.Interface(
106
+ fn=transcribe,
107
+ inputs=gr.Microphone(type="filepath"),
108
+ outputs=gr.Textbox(label="Transcription"),
109
+ theme="huggingface",
110
+ title="Persian ASR Transcription with NeMo Fast Conformer",
111
+ description=(
112
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the NeMo's Fast Conformer Hybrid Large.\n\n"
113
+ "Trained on ~800 hours of Persian speech dataset (Common Voice 17 (~300 hours), YouTube (~400 hours), NasleMana (~90 hours), In-house dataset (~70 hours)).\n\n"
114
+ "For commercial applications, contact us via email: <[email protected]>.\n\n"
115
+ "Credit FAIM Group, Sharif University of Technology.\n\n"
116
+ ),
117
+ allow_flagging="never",
118
+ )
119
+
120
+ # File upload tab
121
+ file_transcribe = gr.Interface(
122
+ fn=transcribe,
123
+ inputs=gr.Audio(type="filepath", label="Audio file"),
124
+ outputs=gr.Textbox(label="Transcription"),
125
+ theme="huggingface",
126
+ title="Persian ASR Transcription with NeMo Fast Conformer",
127
+ description=(
128
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the NeMo's Fast Conformer Hybrid Large.\n\n"
129
+ "Trained on ~800 hours of Persian speech dataset (Common Voice 17 (~300 hours), YouTube (~400 hours), NasleMana (~90 hours), In-house dataset (~70 hours)).\n\n"
130
+ "For commercial applications, contact us via email: <[email protected]>.\n\n"
131
+ "Credit FAIM Group, Sharif University of Technology.\n\n"
132
+ ),
133
+ allow_flagging="never",
134
+ )
135
+
136
+ # YouTube tab
137
+ yt_transcribe = gr.Interface(
138
+ fn=transcribe_yt,
139
+ inputs=gr.Textbox(label="YouTube URL", placeholder="Enter the YouTube URL here"),
140
+ outputs=gr.Textbox(label="Transcription"),
141
+ theme="huggingface",
142
+ title="Transcribe YouTube Video",
143
+ description="Transcribe audio from a YouTube video by providing its URL. Currently YouTube is blocking the requests. So you will see the app showing error",
144
+ allow_flagging="never",
145
+ )
146
+
147
+ # Gradio Interface
148
+ demo = gr.Blocks()
149
+
150
+ with demo:
151
+ # Create the tabs with the list of interfaces
152
+ gr.TabbedInterface([mf_transcribe, file_transcribe, yt_transcribe], ["Microphone", "Audio file", "YouTube"])
153
+
154
+ demo.launch()