Spaces: Running on Zero
Commit · 581b819
Parent(s): 674db8f
gr.Error doesn't display in a function decorated by spaces.GPU - workaround by making the decorated function smaller
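The workaround described in the commit message, stated as a runnable sketch rather than the Space's actual app.py: keep input validation and the gr.Error in a plain event handler, and have it call a small @spaces.GPU-decorated function that does only the GPU work. The component names and the run_on_gpu placeholder below are hypothetical; only gradio and the Hugging Face spaces package are assumed.

import gradio as gr
import spaces


@spaces.GPU
def run_on_gpu(audio_filepath):
    # Hypothetical GPU-only work. A gr.Error raised in here would surface to the
    # user as a generic "ZeroGPU worker error" instead of its own message.
    return f"(transcript of {audio_filepath})"


def on_go_btn_click(audio_filepath):
    # Validation stays in the undecorated handler so the gr.Error message is shown.
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
    return run_on_gpu(audio_filepath)


with gr.Blocks() as demo:
    audio_file = gr.Audio(type="filepath")
    go_button = gr.Button("Go")
    model_output = gr.Textbox()
    go_button.click(fn=on_go_btn_click, inputs=[audio_file], outputs=[model_output])

demo.launch()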
app.py CHANGED
@@ -63,7 +63,40 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
     return out_filename, duration

 @spaces.GPU
-def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
+def transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration, duration_limit):
+    """
+    Transcribe audio using either model.transcribe or buffered inference.
+    Duration limit determines which method to use and what chunk size will
+    be used in the case of buffered inference.
+
+    Note: I have observed that if you try to throw a gr.Error inside a function
+    decorated with @spaces.GPU, the error message you specified in gr.Error will
+    not be shown; instead it shows the message "ZeroGPU worker error".
+    """
+
+    if audio_duration < duration_limit:
+
+        output = model.transcribe(manifest_filepath)
+
+    else:
+        frame_asr = FrameBatchMultiTaskAED(
+            asr_model=model,
+            frame_len=duration_limit,
+            total_buffer=duration_limit,
+            batch_size=16,
+        )
+        output = get_buffered_pred_feat_multitaskAED(
+            frame_asr,
+            model.cfg.preprocessor,
+            model_stride_in_secs,
+            model.device,
+            manifest=manifest_filepath,
+            filepaths=None,
+        )
+    return output
+
+
+def on_go_btn_click(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):

     if audio_filepath is None:
         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
@@ -149,25 +182,7 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):


     if gen_ts == "yes": # if will generate timestamps
-
-        if duration < 10:
-            output = model.transcribe(manifest_filepath)
-        else:
-            frame_asr = FrameBatchMultiTaskAED(
-                asr_model=model,
-                frame_len=10.0,
-                total_buffer=10.0,
-                batch_size=16,
-            )
-
-            output = get_buffered_pred_feat_multitaskAED(
-                frame_asr,
-                model.cfg.preprocessor,
-                model_stride_in_secs,
-                model.device,
-                manifest=manifest_filepath,
-                filepaths=None,
-            )
+        output = transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration=duration, duration_limit=10.0)

         # process output to get word and segment level timestamps
         word_level_timestamps = output[0].timestamp["word"]
@@ -186,25 +201,7 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):
         output_html += "</div>\n"

     else: # if will not generate timestamps
-
-        if duration < 40:
-            output = model.transcribe(manifest_filepath)
-
-        else: # do buffered inference
-            frame_asr = FrameBatchMultiTaskAED(
-                asr_model=model,
-                frame_len=40.0,
-                total_buffer=40.0,
-                batch_size=16,
-            )
-            output = get_buffered_pred_feat_multitaskAED(
-                frame_asr,
-                model.cfg.preprocessor,
-                model_stride_in_secs,
-                model.device,
-                manifest=manifest_filepath,
-                filepaths=None,
-            )
+        output = transcribe(manifest_filepath, model, model_stride_in_secs, audio_duration=duration, duration_limit=40.0)

     if taskname == "asr":
         output_html += "<div class='heading'>Transcript</div>\n"
@@ -403,7 +400,7 @@ with gr.Blocks(
     )

     go_button.click(
-        fn=transcribe,
+        fn=on_go_btn_click,
         inputs = [audio_file, src_lang, tgt_lang, pnc, gen_ts],
         outputs = [model_output_html]
     )