import spaces import gradio as gr import json import librosa import os import soundfile as sf import tempfile import uuid import torch from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED SAMPLE_RATE = 16000 # Hz MAX_AUDIO_MINUTES = 30 # wont try to transcribe if longer than this model = ASRModel.from_pretrained("nvidia/canary-180m-flash") model.eval() # make sure beam size always 1 for consistency model.change_decoding_strategy(None) decoding_cfg = model.cfg.decoding decoding_cfg.beam.beam_size = 1 model.change_decoding_strategy(decoding_cfg) # setup for buffered inference model.cfg.preprocessor.dither = 0.0 model.cfg.preprocessor.pad_to = 0 feature_stride = model.cfg.preprocessor['window_stride'] model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer amp_dtype = torch.float16 def convert_audio(audio_filepath, tmpdir, utt_id): """ Convert all files to monochannel 16 kHz wav files. Do not convert and raise error if audio too long. Returns output filename and duration. """ data, sr = librosa.load(audio_filepath, sr=None, mono=True) duration = librosa.get_duration(y=data, sr=sr) if duration / 60.0 > MAX_AUDIO_MINUTES: raise gr.Error( f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. " "If you wish, you may trim the audio using the Audio viewer in Step 1 " "(click on the scissors icon to start trimming audio)." ) if sr != SAMPLE_RATE: data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE) out_filename = os.path.join(tmpdir, utt_id + '.wav') # save output audio sf.write(out_filename, data, SAMPLE_RATE) return out_filename, duration @spaces.GPU def transcribe(manifest_filepath, audio_duration, duration_limit): """ Transcribe audio using either model.transcribe or buffered inference. Duration limit determines which method to use and what chunk size will be used in the case of buffered inference. Note: I have observed that if you try to throw a gr.Error inside a function decorated with @spaces.GPU, the error message you specified in gr.Error will not be shown, instead it show the message "ZeroGPU worker error". """ if audio_duration < duration_limit: output = model.transcribe(manifest_filepath) else: frame_asr = FrameBatchMultiTaskAED( asr_model=model, frame_len=duration_limit, total_buffer=duration_limit, batch_size=16, ) output = get_buffered_pred_feat_multitaskAED( frame_asr, model.cfg.preprocessor, model_stride_in_secs, model.device, manifest=manifest_filepath, filepaths=None, ) return output def on_go_btn_click(audio_filepath, src_lang, tgt_lang, pnc, gen_ts): if audio_filepath is None: raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone") utt_id = uuid.uuid4() with tempfile.TemporaryDirectory() as tmpdir: converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id)) # map src_lang and tgt_lang from long versions to short LANG_LONG_TO_LANG_SHORT = { "English": "en", "Spanish": "es", "French": "fr", "German": "de", } if src_lang not in LANG_LONG_TO_LANG_SHORT.keys(): raise ValueError(f"src_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}") else: src_lang = LANG_LONG_TO_LANG_SHORT[src_lang] if tgt_lang not in LANG_LONG_TO_LANG_SHORT.keys(): raise ValueError(f"tgt_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}") else: tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang] # infer taskname from src_lang and tgt_lang if src_lang == tgt_lang: taskname = "asr" else: taskname = "s2t_translation" # update pnc and gen_ts variables to be "yes" or "no" pnc = "yes" if pnc else "no" gen_ts = "yes" if gen_ts else "no" # make manifest file and save manifest_data = { "audio_filepath": converted_audio_filepath, "source_lang": src_lang, "target_lang": tgt_lang, "taskname": taskname, "pnc": pnc, "answer": "predict", "duration": str(duration), "timestamp": gen_ts, } manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json') with open(manifest_filepath, 'w') as fout: line = json.dumps(manifest_data) fout.write(line + '\n') # setup beginning of output html output_html = '''
''' if gen_ts == "yes": # if will generate timestamps output = transcribe(manifest_filepath, audio_duration=duration, duration_limit=10.0) # process output to get word and segment level timestamps word_level_timestamps = output[0].timestamp["word"] output_html += "Step 1: Upload an audio file or record with your microphone.
" f"This demo supports audio files up to {MAX_AUDIO_MINUTES} mins long. " "You can transcribe longer files locally with this NeMo " "script.
" ) audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath") gr.HTML( "Step 2: Choose the input and output language.
" "If input & output languages are the same, you can also toggle generating punctuation & capitalization and timestamps.
" ) with gr.Column(): src_lang = gr.Dropdown( choices=["English", "Spanish", "French", "German"], value="English", label="Input audio is spoken in:" ) tgt_lang = gr.Dropdown( choices=["English", "Spanish", "French", "German"], value="English", label="Transcribe in language:" ) pnc = gr.Checkbox( value=True, label="Punctuation & Capitalization in model output?", ) gen_ts = gr.Checkbox( value=False, label="Generate timestamps?", ) with gr.Column(): gr.HTML("Step 3: Run the model.
") go_button = gr.Button( value="Run model", variant="primary", # make "primary" so it stands out (default is "secondary") ) model_output_html = gr.HTML( # initialize with min-height to ensure "processing" animation will be visible value='', label="Model Output", ) with gr.Row(): gr.HTML( "" "🐤 Canary 1B Flash model | " "🧑💻 NeMo Repository" "
" ) go_button.click( fn=on_go_btn_click, inputs = [audio_file, src_lang, tgt_lang, pnc, gen_ts], outputs = [model_output_html] ) # call on_src_or_tgt_lang_change whenever src_lang or tgt_lang dropdown menus are changed src_lang.change( fn=on_src_or_tgt_lang_change, inputs=[src_lang, tgt_lang, pnc, gen_ts], outputs=[src_lang, tgt_lang, pnc, gen_ts], ) tgt_lang.change( fn=on_src_or_tgt_lang_change, inputs=[src_lang, tgt_lang, pnc, gen_ts], outputs=[src_lang, tgt_lang, pnc, gen_ts], ) demo.queue() demo.launch()