import spaces
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_MINUTES = 30  # won't try to transcribe if longer than this

model = ASRModel.from_pretrained("nvidia/canary-1b-flash")
model.eval()

# make sure beam size is always 1 for consistency
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)

# setup for buffered inference
model.cfg.preprocessor.dither = 0.0
model.cfg.preprocessor.pad_to = 0

feature_stride = model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8  # model stride is 8 for FastConformer

amp_dtype = torch.float16
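# How these settings fit together (sketch): dither=0.0 disables the small random noise
# normally added during feature extraction and pad_to=0 avoids padding the feature
# matrix, so chunk boundaries stay consistent for the buffered-inference path below.
# model_stride_in_secs is the audio duration covered by one encoder output frame
# (window_stride seconds per feature frame times the 8x FastConformer subsampling);
# it is passed to get_buffered_pred_feat_multitaskAED inside transcribe() below.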
def convert_audio(audio_filepath, tmpdir, utt_id):
    """
    Convert the input file to a mono 16 kHz wav file.
    Do not convert, and raise an error, if the audio is too long.
    Returns the output filename and the duration.
    """

    data, sr = librosa.load(audio_filepath, sr=None, mono=True)

    duration = librosa.get_duration(y=data, sr=sr)

    if duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    out_filename = os.path.join(tmpdir, utt_id + '.wav')

    # save output audio
    sf.write(out_filename, data, SAMPLE_RATE)

    return out_filename, duration
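# Hypothetical usage, mirroring the call made in on_go_btn_click below
# (the input path here is just an illustrative placeholder):
#   converted_path, duration = convert_audio("/tmp/upload.mp3", tmpdir, str(uuid.uuid4()))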
@spaces.GPU  # run this function on ZeroGPU hardware (see the gr.Error note in the docstring)
def transcribe(manifest_filepath, audio_duration, duration_limit):
    """
    Transcribe audio using either model.transcribe or buffered inference.
    duration_limit determines which method to use and, in the buffered case,
    what chunk size will be used.
    Note: I have observed that if you try to throw a gr.Error inside a function
    decorated with @spaces.GPU, the error message you specified in gr.Error will
    not be shown; instead it shows the message "ZeroGPU worker error".
    """

    if audio_duration < duration_limit:
        output = model.transcribe(manifest_filepath)
    else:
        frame_asr = FrameBatchMultiTaskAED(
            asr_model=model,
            frame_len=duration_limit,
            total_buffer=duration_limit,
            batch_size=16,
        )

        output = get_buffered_pred_feat_multitaskAED(
            frame_asr,
            model.cfg.preprocessor,
            model_stride_in_secs,
            model.device,
            manifest=manifest_filepath,
            filepaths=None,
        )

    return output
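# As consumed by on_go_btn_click below, both code paths return a list of
# hypothesis-like objects: output[0].text holds the transcript / translation, and
# output[0].timestamp["word"] / output[0].timestamp["segment"] hold lists of dicts
# with "word"/"segment", "start" and "end" fields when timestamps were requested.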
def on_go_btn_click(audio_filepath, src_lang, tgt_lang, pnc, gen_ts):

    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    utt_id = uuid.uuid4()

    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))

        # map src_lang and tgt_lang from long versions to short
        LANG_LONG_TO_LANG_SHORT = {
            "English": "en",
            "Spanish": "es",
            "French": "fr",
            "German": "de",
        }
        if src_lang not in LANG_LONG_TO_LANG_SHORT.keys():
            raise ValueError(f"src_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
        else:
            src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]

        if tgt_lang not in LANG_LONG_TO_LANG_SHORT.keys():
            raise ValueError(f"tgt_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
        else:
            tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]

        # infer taskname from src_lang and tgt_lang
        if src_lang == tgt_lang:
            taskname = "asr"
        else:
            taskname = "s2t_translation"

        # update pnc and gen_ts variables to be "yes" or "no"
        pnc = "yes" if pnc else "no"
        gen_ts = "yes" if gen_ts else "no"

        # make manifest file and save
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "taskname": taskname,
            "pnc": pnc,
            "answer": "predict",
            "duration": str(duration),
            "timestamp": gen_ts,
        }

        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')

        with open(manifest_filepath, 'w') as fout:
            line = json.dumps(manifest_data)
            fout.write(line + '\n')
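        # The manifest follows NeMo's one-JSON-object-per-line convention. As a purely
        # illustrative example, a line for same-language transcription without
        # timestamps would look roughly like:
        # {"audio_filepath": "/tmp/.../<utt_id>.wav", "source_lang": "en",
        #  "target_lang": "en", "taskname": "asr", "pnc": "yes",
        #  "answer": "predict", "duration": "12.3", "timestamp": "no"}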
        # setup beginning of output html
        output_html = '''
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <style>
                .transcript {
                    font-family: Arial, sans-serif;
                    line-height: 1.6;
                    margin: 20px 0;
                }
                .timestamp {
                    color: gray;
                    font-size: 0.8em;
                    margin-right: 5px;
                }
                .heading {
                    color: #2c3e50;
                    font-family: Arial, sans-serif;
                    font-weight: bold;
                    margin: 15px 0 8px 0;
                    border-bottom: 1px solid #eee;
                }
            </style>
        </head>
        <body>
        '''
        if gen_ts == "yes":  # generate timestamps
            output = transcribe(manifest_filepath, audio_duration=duration, duration_limit=10.0)

            # process output to get word and segment level timestamps
            word_level_timestamps = output[0].timestamp["word"]
            output_html += "<div class='heading'>Transcript with word-level timestamps (in seconds)</div>\n"
            output_html += "<div class='transcript'>\n"
            for entry in word_level_timestamps:
                output_html += f'<span>{entry["word"]} <span class="timestamp">({entry["start"]:.2f}-{entry["end"]:.2f})</span></span>\n'
            output_html += "</div>\n"

            segment_level_timestamps = output[0].timestamp["segment"]
            output_html += "<div class='heading'>Transcript with segment-level timestamps (in seconds)</div>\n"
            output_html += "<div class='transcript'>\n"
            for entry in segment_level_timestamps:
                output_html += f'<span>{entry["segment"]} <span class="timestamp">({entry["start"]:.2f}-{entry["end"]:.2f})</span></span><br>\n'
            output_html += "</div>\n"

        else:  # do not generate timestamps
            output = transcribe(manifest_filepath, audio_duration=duration, duration_limit=40.0)

            if taskname == "asr":
                output_html += "<div class='heading'>Transcript</div>\n"
            else:
                output_html += "<div class='heading'>Translated Text</div>\n"

            output_text = output[0].text
            output_html += f'<div class="transcript">{output_text}</div>\n'

        output_html += '''
        </body>
        </html>
        '''

    return output_html
# add logic to make sure dropdown menus only suggest valid combos
def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value, gen_ts_value):
    """Callback function for when the src_lang or tgt_lang dropdown menus are changed.

    Args:
        src_lang_value (string), tgt_lang_value (string), pnc_value (bool), gen_ts_value (bool):
            the currently chosen "values" of each Gradio component

    Returns:
        src_lang, tgt_lang, pnc, gen_ts - these are the new Gradio components that will be displayed

    Note: I found the required logic is easier to understand if you think about the possible src & tgt langs as
    a matrix, e.g. with English, Spanish, French, German as the langs, and only transcription in the same language,
    and X -> English and English -> X translation being allowed, the matrix looks like the diagram below
    ("Y" means it is allowed to go into that state).
    It is easier to understand the code if you think about which state you are in, given the current src_lang_value
    and tgt_lang_value, and then which states you can go to from there.

                  tgt lang
             |EN |ES |FR |DE |
             -----------------
          EN | Y | Y | Y | Y |
             -----------------
    src   ES | Y | Y |   |   |
    lang     -----------------
          FR | Y |   | Y |   |
             -----------------
          DE | Y |   |   | Y |
    """
    if src_lang_value == "English" and tgt_lang_value == "English":
        # src_lang and tgt_lang can go anywhere
        src_lang = gr.Dropdown(
            choices=["English", "Spanish", "French", "German"],
            value=src_lang_value,
            label="Input audio is spoken in:"
        )
        tgt_lang = gr.Dropdown(
            choices=["English", "Spanish", "French", "German"],
            value=tgt_lang_value,
            label="Transcribe in language:"
        )
    elif src_lang_value == "English":
        # src is English & tgt is non-English
        # => src can only be English or the current tgt_lang_value
        #    & tgt can be anything
        src_lang = gr.Dropdown(
            choices=["English", tgt_lang_value],
            value=src_lang_value,
            label="Input audio is spoken in:"
        )
        tgt_lang = gr.Dropdown(
            choices=["English", "Spanish", "French", "German"],
            value=tgt_lang_value,
            label="Transcribe in language:"
        )
    elif tgt_lang_value == "English":
        # src is non-English & tgt is English
        # => src can be anything
        #    & tgt can only be English or the current src_lang_value
        src_lang = gr.Dropdown(
            choices=["English", "Spanish", "French", "German"],
            value=src_lang_value,
            label="Input audio is spoken in:"
        )
        tgt_lang = gr.Dropdown(
            choices=["English", src_lang_value],
            value=tgt_lang_value,
            label="Transcribe in language:"
        )
    else:
        # both src and tgt are non-English
        # => both src and tgt can only be switched to English or stay as themselves
        src_lang = gr.Dropdown(
            choices=["English", src_lang_value],
            value=src_lang_value,
            label="Input audio is spoken in:"
        )
        tgt_lang = gr.Dropdown(
            choices=["English", tgt_lang_value],
            value=tgt_lang_value,
            label="Transcribe in language:"
        )

    # if src_lang_value == tgt_lang_value then pnc and gen_ts can be anything
    # else, fix pnc to True and gen_ts to False
    if src_lang_value == tgt_lang_value:
        pnc = gr.Checkbox(
            value=pnc_value,
            label="Punctuation & Capitalization in model output?",
            interactive=True
        )
        gen_ts = gr.Checkbox(
            value=gen_ts_value,
            label="Generate timestamps?",
            interactive=True
        )
    else:
        pnc = gr.Checkbox(
            value=True,
            label="Punctuation & Capitalization in model output?",
            interactive=False
        )
        gen_ts = gr.Checkbox(
            value=False,
            label="Generate timestamps?",
            interactive=False
        )

    return src_lang, tgt_lang, pnc, gen_ts
with gr.Blocks(
    title="NeMo Canary 1B Flash Model",
    css="""
        textarea { font-size: 18px; }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
) as demo:

    gr.HTML("<h1 style='text-align: center'>NeMo Canary 1B Flash model: Transcribe & Translate audio</h1>")

    with gr.Row():
        with gr.Column():
            gr.HTML(
                "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
                f"<p style='color: #A0A0A0;'>This demo supports audio files up to {MAX_AUDIO_MINUTES} mins long. "
                "You can transcribe longer files locally with this NeMo "
                "<a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py'>script</a>.</p>"
            )

            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")

            gr.HTML(
                "<p><b>Step 2:</b> Choose the input and output language.</p>"
                "<p style='color: #A0A0A0;'>If input & output languages are the same, you can also toggle generating punctuation & capitalization and timestamps.</p>"
            )

        with gr.Column():
            src_lang = gr.Dropdown(
                choices=["English", "Spanish", "French", "German"],
                value="English",
                label="Input audio is spoken in:"
            )
            tgt_lang = gr.Dropdown(
                choices=["English", "Spanish", "French", "German"],
                value="English",
                label="Transcribe in language:"
            )
            pnc = gr.Checkbox(
                value=True,
                label="Punctuation & Capitalization in model output?",
            )
            gen_ts = gr.Checkbox(
                value=False,
                label="Generate timestamps?",
            )

        with gr.Column():
            gr.HTML("<p><b>Step 3:</b> Run the model.</p>")
            go_button = gr.Button(
                value="Run model",
                variant="primary",  # make "primary" so it stands out (default is "secondary")
            )
            model_output_html = gr.HTML(
                # initialize with min-height to ensure "processing" animation will be visible
                value='<div style="min-height: 100px;"></div>',
                label="Model Output",
            )
    with gr.Row():
        gr.HTML(
            "<p style='text-align: center'>"
            "🤗 <a href='https://huggingface.co/nvidia/canary-1b-flash' target='_blank'>Canary 1B Flash model</a> | "
            "🧑‍💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
            "</p>"
        )
    go_button.click(
        fn=on_go_btn_click,
        inputs=[audio_file, src_lang, tgt_lang, pnc, gen_ts],
        outputs=[model_output_html]
    )

    # call on_src_or_tgt_lang_change whenever the src_lang or tgt_lang dropdown menus are changed
    src_lang.change(
        fn=on_src_or_tgt_lang_change,
        inputs=[src_lang, tgt_lang, pnc, gen_ts],
        outputs=[src_lang, tgt_lang, pnc, gen_ts],
    )
    tgt_lang.change(
        fn=on_src_or_tgt_lang_change,
        inputs=[src_lang, tgt_lang, pnc, gen_ts],
        outputs=[src_lang, tgt_lang, pnc, gen_ts],
    )

demo.queue()
demo.launch()