erastorgueva-nv committed
Commit 581b819 · 1 Parent(s): 674db8f

gr.Error doesn't display in function decorated by spaces.GPU - workaround by making decorated function smaller
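The pattern behind this workaround, as a minimal self-contained sketch (the run_gpu_work/on_click names and the toy task are illustrative, not taken from app.py): keep the @spaces.GPU-decorated function small and GPU-only, and raise gr.Error from the undecorated event handler so the message is not swallowed by the ZeroGPU worker.

import gradio as gr
import spaces  # available on Hugging Face ZeroGPU Spaces

@spaces.GPU
def run_gpu_work(text):
    # Keep only the GPU work here and raise no gr.Error inside:
    # an exception raised in the ZeroGPU worker reaches the UI
    # only as a generic "ZeroGPU worker error".
    return text.upper()  # stand-in for model inference

def on_click(text):
    # Validate (and raise gr.Error) in the undecorated handler,
    # so the user sees the message as written.
    if not text:
        raise gr.Error("Please provide some input text")
    return run_gpu_work(text)

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Go").click(fn=on_click, inputs=[inp], outputs=[out])

demo.launch()

In the diff below, on_go_btn_click plays the role of the undecorated handler, and the slimmed-down transcribe is the only @spaces.GPU function.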

Files changed (1)
  1. app.py +37 -40
app.py CHANGED
@@ -63,7 +63,40 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
     return out_filename, duration
 
 @spaces.GPU
-def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
+def transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration, duration_limit):
+    """
+    Transcribe audio using either model.transcribe or buffered inference.
+    Duration limit determines which method to use and what chunk size will
+    be used in the case of buffered inference.
+
+    Note: I have observed that if you try to throw a gr.Error inside a function
+    decorated with @spaces.GPU, the error message you specified in gr.Error will
+    not be shown; instead it shows the message "ZeroGPU worker error".
+    """
+
+    if audio_duration < duration_limit:
+
+        output = model.transcribe(manifest_filepath)
+
+    else:
+        frame_asr = FrameBatchMultiTaskAED(
+            asr_model=model,
+            frame_len=duration_limit,
+            total_buffer=duration_limit,
+            batch_size=16,
+        )
+        output = get_buffered_pred_feat_multitaskAED(
+            frame_asr,
+            model.cfg.preprocessor,
+            model_stride_in_secs,
+            model.device,
+            manifest=manifest_filepath,
+            filepaths=None,
+        )
+    return output
+
+
+def on_go_btn_click(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
 
     if audio_filepath is None:
         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
@@ -149,25 +182,7 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
 
 
     if gen_ts == "yes": # if will generate timestamps
-
-        if duration < 10:
-            output = model.transcribe(manifest_filepath)
-        else:
-            frame_asr = FrameBatchMultiTaskAED(
-                asr_model=model,
-                frame_len=10.0,
-                total_buffer=10.0,
-                batch_size=16,
-            )
-
-            output = get_buffered_pred_feat_multitaskAED(
-                frame_asr,
-                model.cfg.preprocessor,
-                model_stride_in_secs,
-                model.device,
-                manifest=manifest_filepath,
-                filepaths=None,
-            )
+        output = transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration=duration, duration_limit=10.0)
 
         # process output to get word and segment level timestamps
         word_level_timestamps = output[0].timestamp["word"]
@@ -186,25 +201,7 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
         output_html += "</div>\n"
 
     else: # if will not generate timestamps
-
-        if duration < 40:
-            output = model.transcribe(manifest_filepath)
-
-        else: # do buffered inference
-            frame_asr = FrameBatchMultiTaskAED(
-                asr_model=model,
-                frame_len=40.0,
-                total_buffer=40.0,
-                batch_size=16,
-            )
-            output = get_buffered_pred_feat_multitaskAED(
-                frame_asr,
-                model.cfg.preprocessor,
-                model_stride_in_secs,
-                model.device,
-                manifest=manifest_filepath,
-                filepaths=None,
-            )
+        output = transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration=duration, duration_limit=40.0)
 
         if taskname == "asr":
             output_html += "<div class='heading'>Transcript</div>\n"
@@ -403,7 +400,7 @@ with gr.Blocks(
     )
 
     go_button.click(
-        fn=transcribe,
+        fn=on_go_btn_click,
         inputs = [audio_file, src_lang, tgt_lang, pnc, gen_ts],
         outputs = [model_output_html]
     )