erastorgueva-nv committed
Commit 8c427b0 · 1 Parent(s): 3732fef

take canary-1b code and update model name -> can do inference on audio under 40 seconds without timestamps

Files changed (5)
  1. README.md +1 -1
  2. app.py +326 -3
  3. packages.txt +2 -0
  4. pre-requirements.txt +1 -0
  5. requirements.txt +2 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Canary 1b Flash
- emoji: 📊
+ emoji: 🐤
  colorFrom: yellow
  colorTo: purple
  sdk: gradio
app.py CHANGED
@@ -1,7 +1,330 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ import json
+ import librosa
+ import os
+ import soundfile as sf
+ import tempfile
+ import uuid
+
+ import torch
+
+ from nemo.collections.asr.models import ASRModel
+ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
+ from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
+
+ SAMPLE_RATE = 16000  # Hz
+ MAX_AUDIO_MINUTES = 10  # won't try to transcribe if longer than this
+
+ model = ASRModel.from_pretrained("nvidia/canary-1b-flash")
+ model.eval()
+
+ # make sure beam size is always 1 for consistency
+ model.change_decoding_strategy(None)
+ decoding_cfg = model.cfg.decoding
+ decoding_cfg.beam.beam_size = 1
+ model.change_decoding_strategy(decoding_cfg)
+
+ # setup for buffered inference
+ model.cfg.preprocessor.dither = 0.0
+ model.cfg.preprocessor.pad_to = 0
+
+ feature_stride = model.cfg.preprocessor['window_stride']
+ model_stride_in_secs = feature_stride * 8  # 8 = model stride, which is 8 for FastConformer
+
+ frame_asr = FrameBatchMultiTaskAED(
+     asr_model=model,
+     frame_len=40.0,
+     total_buffer=40.0,
+     batch_size=16,
+ )
+
+ amp_dtype = torch.float16
+
+ def convert_audio(audio_filepath, tmpdir, utt_id):
+     """
+     Convert all files to mono-channel 16 kHz wav files.
+     Do not convert, and raise an error, if the audio is too long.
+     Returns output filename and duration.
+     """
+
+     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration / 60.0 > MAX_AUDIO_MINUTES:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
+             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+             "(click on the scissors icon to start trimming audio)."
+         )
+
+     if sr != SAMPLE_RATE:
+         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+     out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+     # save output audio
+     sf.write(out_filename, data, SAMPLE_RATE)
+
+     return out_filename, duration
+
+
+ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
+
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+         # map src_lang and tgt_lang from long versions to short
+         LANG_LONG_TO_LANG_SHORT = {
+             "English": "en",
+             "Spanish": "es",
+             "French": "fr",
+             "German": "de",
+         }
+         if src_lang not in LANG_LONG_TO_LANG_SHORT.keys():
+             raise ValueError(f"src_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
+         else:
+             src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]
+
+         if tgt_lang not in LANG_LONG_TO_LANG_SHORT.keys():
+             raise ValueError(f"tgt_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
+         else:
+             tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]
+
+
+         # infer taskname from src_lang and tgt_lang
+         if src_lang == tgt_lang:
+             taskname = "asr"
+         else:
+             taskname = "s2t_translation"
+
+         # update pnc variable to be "yes" or "no"
+         pnc = "yes" if pnc else "no"
+
+         # make manifest file and save
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": src_lang,
+             "target_lang": tgt_lang,
+             "taskname": taskname,
+             "pnc": pnc,
+             "answer": "predict",
+             "duration": str(duration),
+         }
+
+         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+         with open(manifest_filepath, 'w') as fout:
+             line = json.dumps(manifest_data)
+             fout.write(line + '\n')
+
+         # call transcribe, passing in manifest filepath
+         if duration < 40:
+             output_text = model.transcribe(manifest_filepath)[0].text
+         else:  # do buffered inference
+             with torch.cuda.amp.autocast(dtype=amp_dtype):  # TODO: make it work if no cuda
+                 with torch.no_grad():
+                     hyps = get_buffered_pred_feat_multitaskAED(
+                         frame_asr,
+                         model.cfg.preprocessor,
+                         model_stride_in_secs,
+                         model.device,
+                         manifest=manifest_filepath,
+                         filepaths=None,
+                     )
+
+                     output_text = hyps[0].text
+
+     return output_text
+
+ # add logic to make sure dropdown menus only suggest valid combos
+ def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
+     """Callback function for when the src_lang or tgt_lang dropdown menus are changed.
+
+     Args:
+         src_lang_value (string), tgt_lang_value (string), pnc_value (bool) - the current
+             chosen "values" of each Gradio component
+     Returns:
+         src_lang, tgt_lang, pnc - these are the new Gradio components that will be displayed
+
+     Note: I found the required logic is easier to understand if you think about the possible src & tgt langs as
+     a matrix, e.g. with English, Spanish, French, German as the langs, and only transcription in the same language,
+     and X -> English and English -> X translation being allowed, the matrix looks like the diagram below ("Y" means it is
+     allowed to go into that state).
+     It is easier to understand the code if you think about which state you are in, given the current src_lang_value and
+     tgt_lang_value, and then which states you can go to from there.
+
+                    tgt lang
+              -  |EN |ES |FR |DE
+             ------------------
+           EN | Y | Y | Y | Y
+             ------------------
+     src   ES | Y | Y |   |
+     lang    ------------------
+           FR | Y |   | Y |
+             ------------------
+           DE | Y |   |   | Y
+     """
+
+     if src_lang_value == "English" and tgt_lang_value == "English":
+         # src_lang and tgt_lang can go anywhere
+         src_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     elif src_lang_value == "English":
+         # src is English & tgt is non-English
+         # => src can only be English or the current tgt_lang_value
+         # & tgt can be anything
+         src_lang = gr.Dropdown(
+             choices=["English", tgt_lang_value],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     elif tgt_lang_value == "English":
+         # src is non-English & tgt is English
+         # => src can be anything
+         # & tgt can only be English or the current src_lang_value
+         src_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", src_lang_value],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     else:
+         # both src and tgt are non-English
+         # => both src and tgt can only be switched to English or stay as themselves
+         src_lang = gr.Dropdown(
+             choices=["English", src_lang_value],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", tgt_lang_value],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     # let pnc be anything if src_lang_value == tgt_lang_value, else fix it to True
+     if src_lang_value == tgt_lang_value:
+         pnc = gr.Checkbox(
+             value=pnc_value,
+             label="Punctuation & Capitalization in transcript?",
+             interactive=True
+         )
+     else:
+         pnc = gr.Checkbox(
+             value=True,
+             label="Punctuation & Capitalization in transcript?",
+             interactive=False
+         )
+     return src_lang, tgt_lang, pnc
+
+
+ with gr.Blocks(
+     title="NeMo Canary 1B Flash Model",
+     css="""
+     textarea { font-size: 18px;}
+     #model_output_text_box span {
+         font-size: 18px;
+         font-weight: bold;
+     }
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
+ ) as demo:
+
+     gr.HTML("<h1 style='text-align: center'>NeMo Canary 1B Flash model: Transcribe & Translate audio</h1>")
+
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(
+                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+
+                 "<p style='color: #A0A0A0;'>This demo supports audio files up to 10 mins long. "
+                 "You can transcribe longer files locally with this NeMo "
+                 "<a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py'>script</a>.</p>"
+             )
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+             gr.HTML("<p><b>Step 2:</b> Choose the input and output language.</p>")
+
+             src_lang = gr.Dropdown(
+                 choices=["English", "Spanish", "French", "German"],
+                 value="English",
+                 label="Input audio is spoken in:"
+             )
+
+         with gr.Column():
+             tgt_lang = gr.Dropdown(
+                 choices=["English", "Spanish", "French", "German"],
+                 value="English",
+                 label="Transcribe in language:"
+             )
+             pnc = gr.Checkbox(
+                 value=True,
+                 label="Punctuation & Capitalization in transcript?",
+             )
+
+         with gr.Column():
+
+             gr.HTML("<p><b>Step 3:</b> Run the model.</p>")
+
+             go_button = gr.Button(
+                 value="Run model",
+                 variant="primary",  # make it "primary" so it stands out (default is "secondary")
+             )
+
+             model_output_text_box = gr.Textbox(
+                 label="Model Output",
+                 elem_id="model_output_text_box",
+             )
+
+     with gr.Row():
+
+         gr.HTML(
+             "<p style='text-align: center'>"
+             "🐤 <a href='https://huggingface.co/nvidia/canary-1b-flash' target='_blank'>Canary 1B Flash model</a> | "
+             "🧑‍💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
+             "</p>"
+         )
+
+     go_button.click(
+         fn=transcribe,
+         inputs=[audio_file, src_lang, tgt_lang, pnc],
+         outputs=[model_output_text_box]
+     )
+
+     # call on_src_or_tgt_lang_change whenever the src_lang or tgt_lang dropdown menus are changed
+     src_lang.change(
+         fn=on_src_or_tgt_lang_change,
+         inputs=[src_lang, tgt_lang, pnc],
+         outputs=[src_lang, tgt_lang, pnc],
+     )
+     tgt_lang.change(
+         fn=on_src_or_tgt_lang_change,
+         inputs=[src_lang, tgt_lang, pnc],
+         outputs=[src_lang, tgt_lang, pnc],
+     )
+
+
+ demo.queue()
  demo.launch()
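
For context, the transcribe() helper added above drives the model through a one-line JSON manifest rather than a raw file path. A minimal sketch (illustrative values, not part of the commit) of the manifest line it writes for a hypothetical 12-second English -> German request:

# Sketch only: mirrors the manifest_data dict that transcribe() above serializes.
import json

manifest_line = {
    "audio_filepath": "/tmp/example/converted.wav",  # hypothetical temp-dir path from convert_audio()
    "source_lang": "en",
    "target_lang": "de",
    "taskname": "s2t_translation",  # src != tgt, so translation rather than "asr"
    "pnc": "yes",                   # the UI forces punctuation on for translation
    "answer": "predict",
    "duration": "12.0",
}
print(json.dumps(manifest_line))

Audio shorter than 40 seconds goes straight through model.transcribe() on this manifest; longer audio falls back to the buffered FrameBatchMultiTaskAED path.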
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+ libsndfile1
pre-requirements.txt ADDED
@@ -0,0 +1 @@
+ Cython
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@6c229852a7d351b7cb5e0424ef23658cccd703f6 # using new PEP 508 syntax; recent version of main at time of writing
+ gradio==5.21.0 # latest version at time of writing
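
On Hugging Face Spaces, pre-requirements.txt is installed before requirements.txt, so Cython is available when the pinned nemo_toolkit commit is built, and packages.txt supplies the ffmpeg and libsndfile1 system libraries that librosa and soundfile need for audio decoding. A quick sanity check of that environment (a sketch, using only imports that app.py itself makes):

# Rough environment check: confirms the NeMo classes app.py relies on are importable
# and reports whether a CUDA device is visible.
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED

print("CUDA available:", torch.cuda.is_available())
print("NeMo imports OK:", ASRModel.__name__, FrameBatchMultiTaskAED.__name__)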