erastorgueva-nv committed on
Commit b43c4a1 · 1 Parent(s): 9826ad7

add functionality to generate and display timestamps if transcribing

Files changed (1):
  1. app.py +135 -51
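
For context on the change: when timestamp generation is requested, the new code reads word- and segment-level entries (each with start/end times in seconds) from the model output and renders them as HTML. Below is a minimal, self-contained sketch of that rendering step, using invented sample entries shaped like the ones the updated transcribe() reads from output[0].timestamp["word"] in the diff:

# Illustrative sketch only: mock entries shaped like output[0].timestamp["word"];
# the words and times below are invented.
word_level_timestamps = [
    {"word": "hello", "start": 0.12, "end": 0.48},
    {"word": "world", "start": 0.52, "end": 0.90},
]

output_html = "<div class='transcript'>\n"
for entry in word_level_timestamps:
    output_html += f'<span>{entry["word"]} <span class="timestamp">({entry["start"]:.2f}-{entry["end"]:.2f})</span></span>\n'
output_html += "</div>\n"
print(output_html)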
app.py CHANGED
@@ -14,7 +14,7 @@ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED

SAMPLE_RATE = 16000 # Hz
- MAX_AUDIO_MINUTES = 10 # wont try to transcribe if longer than this
+ MAX_AUDIO_MINUTES = 30 # wont try to transcribe if longer than this

model = ASRModel.from_pretrained("nvidia/canary-1b-flash")
model.eval()
@@ -32,7 +32,14 @@ model.cfg.preprocessor.pad_to = 0
feature_stride = model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer

- frame_asr = FrameBatchMultiTaskAED(
+ frame_asr_10s = FrameBatchMultiTaskAED(
+     asr_model=model,
+     frame_len=10.0,
+     total_buffer=10.0,
+     batch_size=16,
+ )
+
+ frame_asr_40s = FrameBatchMultiTaskAED(
    asr_model=model,
    frame_len=40.0,
    total_buffer=40.0,
@@ -69,9 +76,8 @@ def convert_audio(audio_filepath, tmpdir, utt_id):

    return out_filename, duration

-
@spaces.GPU
- def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
+ def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):

    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
@@ -104,8 +110,9 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
        else:
            taskname = "s2t_translation"

-         # update pnc variable to be "yes" or "no"
+         # update pnc and gen_ts variables to be "yes" or "no"
        pnc = "yes" if pnc else "no"
+         gen_ts = "yes" if gen_ts else "no"

        # make manifest file and save
        manifest_data = {
@@ -116,6 +123,7 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
            "pnc": pnc,
            "answer": "predict",
            "duration": str(duration),
+             "timestamp": gen_ts,
        }

        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
@@ -124,34 +132,95 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
            line = json.dumps(manifest_data)
            fout.write(line + '\n')

-         # call transcribe, passing in manifest filepath
-         if duration < 40:
-             output_text = model.transcribe(manifest_filepath)[0].text
-         else: # do buffered inference
-             with torch.cuda.amp.autocast(dtype=amp_dtype): # TODO: make it work if no cuda
-                 with torch.no_grad():
-                     hyps = get_buffered_pred_feat_multitaskAED(
-                         frame_asr,
-                         model.cfg.preprocessor,
-                         model_stride_in_secs,
-                         model.device,
-                         manifest=manifest_filepath,
-                         filepaths=None,
-                     )
-
-             output_text = hyps[0].text
-
-     return output_text.strip()
+
+         # setup beginning of output html
+         output_html = '''
+         <!DOCTYPE html>
+         <html lang="en">
+         <head>
+         <style>
+
+             .transcript {
+                 font-family: Arial, sans-serif;
+                 line-height: 1.6;
+             }
+             .timestamp {
+                 color: gray;
+                 font-size: 0.8em;
+                 margin-right: 5px;
+             }
+         </style>
+         </head>
+         <body>
+         '''
+
+
+         if gen_ts == "yes": # if will generate timestamps
+
+             if duration < 10:
+                 output = model.transcribe(manifest_filepath)
+             else:
+                 output = get_buffered_pred_feat_multitaskAED(
+                     frame_asr_10s,
+                     model.cfg.preprocessor,
+                     model_stride_in_secs,
+                     model.device,
+                     manifest=manifest_filepath,
+                     filepaths=None,
+                 )
+
+             # process output to get word and segment level timestamps
+             word_level_timestamps = output[0].timestamp["word"]
+
+             output_html += "<p><b>Transcript with word-level timestamps (in seconds)</b></p>\n"
+             output_html += "<div class='transcript'>\n"
+             for entry in word_level_timestamps:
+                 output_html += f'<span>{entry["word"]} <span class="timestamp">({entry["start"]:.2f}-{entry["end"]:.2f})</span></span>\n'
+             output_html += "</div>\n"
+
+             segment_level_timestamps = output[0].timestamp["segment"]
+             output_html += "<p><b>Transcript with segment-level timestamps (in seconds)</b></p>\n"
+             output_html += "<div class='transcript'>\n"
+             for entry in segment_level_timestamps:
+                 output_html += f'<span>{entry["segment"]} <span class="timestamp">({entry["start"]:.2f}-{entry["end"]:.2f})</span></span>\n'
+             output_html += "</div>\n"
+
+         else: # if will not generate timestamps
+
+             if duration < 40:
+                 output = model.transcribe(manifest_filepath)
+
+             else: # do buffered inference
+                 output = get_buffered_pred_feat_multitaskAED(
+                     frame_asr_40s,
+                     model.cfg.preprocessor,
+                     model_stride_in_secs,
+                     model.device,
+                     manifest=manifest_filepath,
+                     filepaths=None,
+                 )
+
+             output_html += "<p><b>Transcript</b></p>\n"
+             output_text = output[0].text
+             output_html += f'<div class="transcript">{output_text}</div>\n'
+
+         output_html += '''
+         </div>
+         </body>
+         </html>
+         '''
+
+     return output_html

# add logic to make sure dropdown menus only suggest valid combos
- def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
+ def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value, gen_ts_value):
    """Callback function for when src_lang or tgt_lang dropdown menus are changed.

    Args:
-         src_lang_value(string), tgt_lang_value (string), pnc_value(bool) - the current
+         src_lang_value(string), tgt_lang_value (string), pnc_value(bool), gen_ts_value(bool) - the current
        chosen "values" of each Gradio component
    Returns:
-         src_lang, tgt_lang, pnc - these are the new Gradio components that will be displayed
+         src_lang, tgt_lang, pnc, gen_ts - these are the new Gradio components that will be displayed

    Note: I found the required logic is easier to understand if you think about the possible src & tgt langs as
    a matrix, e.g. with English, Spanish, French, German as the langs, and only transcription in the same language,
@@ -225,30 +294,38 @@ def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
        value=tgt_lang_value,
        label="Transcribe in language:"
    )
-     # let pnc be anything if src_lang_value == tgt_lang_value, else fix to True
+     # if src_lang_value == tgt_lang_value then pnc and gen_ts can be anything
+     # else, fix pnc to True and gen_ts to False
    if src_lang_value == tgt_lang_value:
        pnc = gr.Checkbox(
            value=pnc_value,
-             label="Punctuation & Capitalization in transcript?",
+             label="Punctuation & Capitalization in model output?",
+             interactive=True
+         )
+         gen_ts = gr.Checkbox(
+             value=gen_ts_value,
+             label="Generate timestamps?",
            interactive=True
        )
    else:
        pnc = gr.Checkbox(
            value=True,
-             label="Punctuation & Capitalization in transcript?",
+             label="Punctuation & Capitalization in model output?",
+             interactive=False
+         )
+         gen_ts = gr.Checkbox(
+             value=False,
+             label="Generate timestamps?",
            interactive=False
        )
-     return src_lang, tgt_lang, pnc
+
+     return src_lang, tgt_lang, pnc, gen_ts


with gr.Blocks(
    title="NeMo Canary 1B Flash Model",
    css="""
    textarea { font-size: 18px;}
-     #model_output_text_box span {
-         font-size: 18px;
-         font-weight: bold;
-     }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) # make text slightly bigger (default is text_md )
) as demo:
@@ -260,22 +337,25 @@ with gr.Blocks(
    gr.HTML(
        "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"

-         "<p style='color: #A0A0A0;'>This demo supports audio files up to 10 mins long. "
+         f"<p style='color: #A0A0A0;'>This demo supports audio files up to {MAX_AUDIO_MINUTES} mins long. "
        "You can transcribe longer files locally with this NeMo "
        "<a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py'>script</a>.</p>"
    )
-
    audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")

-     gr.HTML("<p><b>Step 2:</b> Choose the input and output language.</p>")
-
-     src_lang = gr.Dropdown(
-         choices=["English", "Spanish", "French", "German"],
-         value="English",
-         label="Input audio is spoken in:"
+     gr.HTML(
+         "<p><b>Step 2:</b> Choose the input and output language.</p>"
+         "<p style='color: #A0A0A0;'>If input & output languages are the same, you can also toggle generating punctuation & capitalization and timestamps.</p>"
    )

+
    with gr.Column():
+         src_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value="English",
+             label="Input audio is spoken in:"
+         )
+
        tgt_lang = gr.Dropdown(
            choices=["English", "Spanish", "French", "German"],
            value="English",
@@ -283,7 +363,11 @@ with gr.Blocks(
        )
        pnc = gr.Checkbox(
            value=True,
-             label="Punctuation & Capitalization in transcript?",
+             label="Punctuation & Capitalization in model output?",
+         )
+         gen_ts = gr.Checkbox(
+             value=True,
+             label="Generate timestamps?",
        )

    with gr.Column():
@@ -295,11 +379,11 @@ with gr.Blocks(
            variant="primary", # make "primary" so it stands out (default is "secondary")
        )

-         model_output_text_box = gr.Textbox(
+         model_output_html = gr.HTML(
            label="Model Output",
-             elem_id="model_output_text_box",
        )

+
    with gr.Row():

        gr.HTML(
@@ -311,20 +395,20 @@

    go_button.click(
        fn=transcribe,
-         inputs = [audio_file, src_lang, tgt_lang, pnc],
-         outputs = [model_output_text_box]
+         inputs = [audio_file, src_lang, tgt_lang, pnc, gen_ts],
+         outputs = [model_output_html]
    )

    # call on_src_or_tgt_lang_change whenever src_lang or tgt_lang dropdown menus are changed
    src_lang.change(
        fn=on_src_or_tgt_lang_change,
-         inputs=[src_lang, tgt_lang, pnc],
-         outputs=[src_lang, tgt_lang, pnc],
+         inputs=[src_lang, tgt_lang, pnc, gen_ts],
+         outputs=[src_lang, tgt_lang, pnc, gen_ts],
    )
    tgt_lang.change(
        fn=on_src_or_tgt_lang_change,
-         inputs=[src_lang, tgt_lang, pnc],
-         outputs=[src_lang, tgt_lang, pnc],
+         inputs=[src_lang, tgt_lang, pnc, gen_ts],
+         outputs=[src_lang, tgt_lang, pnc, gen_ts],
    )

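
One way to read the new inference routing in transcribe(): with timestamps enabled, audio shorter than 10 seconds goes through model.transcribe() directly and anything longer through the 10-second buffered decoder (frame_asr_10s); with timestamps disabled, the cutoff is 40 seconds and frame_asr_40s handles the long case. A compact sketch of that decision follows (simplified from the diff; pick_decoder is a hypothetical helper, not part of app.py):

# Hypothetical helper summarizing the routing added in this commit.
def pick_decoder(duration_s: float, gen_ts: bool) -> str:
    cutoff = 10 if gen_ts else 40
    if duration_s < cutoff:
        return "model.transcribe"  # single-pass decoding
    return "frame_asr_10s" if gen_ts else "frame_asr_40s"  # buffered decoding

print(pick_decoder(12.0, gen_ts=True))   # -> "frame_asr_10s"
print(pick_decoder(12.0, gen_ts=False))  # -> "model.transcribe"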