Spaces:
Running
Running
| import os | |
| import gradio as gr | |
| import torchaudio | |
| from typing import Tuple, Optional | |
| import soundfile as sf | |
| from s2st_inference import s2st_inference | |
| from utils import download_model | |
# Audio I/O settings.
SAMPLE_RATE = 16_000  # Hz; input audio is resampled to this rate before inference
MAX_INPUT_LENGTH = 60  # seconds; longer inputs are truncated

# Speech-to-unit translation model: tag passed to download_model + local target dir.
S2UT_TAG = 'espnet/jiyang_tang_cvss-c_es-en_discrete_unit'
S2UT_DIR = 'model'

# Unit vocoder: tag passed to download_model + local target dir.
VOCODER_TAG = 'espnet/cvss-c_en_wavegan_hubert_vocoder'
VOCODER_DIR = 'vocoder'

# Inference settings forwarded to s2st_inference.
NGPU = 0  # number of GPUs (0 presumably means CPU-only)
BEAM_SIZE = 1  # beam width of 1, i.e. greedy search
class App:
    """Backend for the speech-to-speech translation demo.

    On construction, downloads the S2UT model (``S2UT_TAG``) and the vocoder
    (``VOCODER_TAG``) through ``download_model``. :meth:`s2st` is the handler
    wired to the "Translate" button.
    """

    def __init__(self):
        # Download models (target directories are created if missing).
        os.makedirs(S2UT_DIR, exist_ok=True)
        os.makedirs(VOCODER_DIR, exist_ok=True)
        self.s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
        self.vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)

    def s2st(
        self,
        audio_source: str,
        input_audio_mic: Optional[str],
        input_audio_file: Optional[str],
    ):
        """Translate the selected input recording and save the result.

        Args:
            audio_source: ``'file'`` selects ``input_audio_file``; any other
                value selects ``input_audio_mic``.
            input_audio_mic: Path to the microphone recording, or ``None``.
            input_audio_file: Path to the uploaded audio file, or ``None``.

        Returns:
            ``(output_path, text)``: the path of the written translated WAV
            file and a short status string naming the audio source.

        Raises:
            gr.Error: If no audio was provided for the selected source.
        """
        input_path = input_audio_file if audio_source == 'file' else input_audio_mic
        if input_path is None:
            # gr.Error is only shown to the user when *raised*; merely
            # constructing it does nothing.
            raise gr.Error("No input audio provided. Please record or upload audio.")

        orig_wav, orig_sr = torchaudio.load(input_path)
        wav = torchaudio.functional.resample(
            orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE
        )
        max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
        if wav.shape[1] > max_length:
            wav = wav[:, :max_length]
            gr.Warning(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
        wav = wav[0]  # keep only the first channel (mono)

        # Resolve paths *before* changing directory, so relative paths returned
        # by download_model still point at the right files afterwards.
        model_dir = os.path.abspath(self.s2ut_path)
        vocoder_dir = os.path.abspath(self.vocoder_path)
        exp_dir = os.path.join(
            model_dir, 'exp', 's2st_train_s2st_discrete_unit_raw_fbank_es_en'
        )

        # Temporarily change cwd to the model dir so that it loads correctly
        # (presumably the config references files relative to it). Restore it
        # in `finally` so a failed inference does not strand the process there.
        # NOTE(review): os.chdir is process-wide; concurrent requests would race.
        cwd = os.getcwd()
        os.chdir(model_dir)
        try:
            out_wav = s2st_inference(
                wav,
                train_config=os.path.join(exp_dir, 'config.yaml'),
                model_file=os.path.join(exp_dir, '500epoch.pth'),
                vocoder_file=os.path.join(vocoder_dir, 'checkpoint-450000steps.pkl'),
                vocoder_config=os.path.join(vocoder_dir, 'config.yml'),
                ngpu=NGPU,
                beam_size=BEAM_SIZE,
            )
        finally:
            os.chdir(cwd)

        # Save result relative to the restored working directory.
        # NOTE(review): assumes the vocoder emits 16 kHz audio; confirm against
        # the vocoder's config.yml.
        output_path = 'output.wav'
        sf.write(
            output_path,
            out_wav,
            16000,
            "PCM_16",
        )
        return output_path, f'Source: {audio_source}'
def update_audio_ui(audio_source: str) -> Tuple[dict, dict]:
    """Toggle which input widget is visible for the chosen audio source.

    The microphone recorder is shown when ``audio_source == "microphone"``;
    otherwise the file uploader is shown. Both widgets have their value
    cleared.

    Returns:
        Gradio updates for ``(input_audio_mic, input_audio_file)``, in that order.
    """
    use_microphone = audio_source == "microphone"
    mic_update = gr.update(visible=use_microphone, value=None)
    file_update = gr.update(visible=not use_microphone, value=None)
    return mic_update, file_update
def main():
    """Build the Gradio UI, wire its events to :class:`App`, and launch it."""
    app = App()
    with gr.Blocks() as demo:
        with gr.Group():
            with gr.Row():
                audio_source = gr.Radio(
                    label="Audio source",
                    choices=["file", "microphone"],
                    value="file",
                )
                # Only one of the two inputs is visible at a time; see update_audio_ui.
                input_audio_mic = gr.Audio(
                    label="Input speech",
                    type="filepath",
                    source="microphone",
                    visible=False,
                )
                input_audio_file = gr.Audio(
                    label="Input speech",
                    type="filepath",
                    source="upload",
                    visible=True,
                )
            btn = gr.Button("Translate")
            with gr.Column():
                output_audio = gr.Audio(
                    label="Translated speech",
                    autoplay=False,
                    streaming=False,
                    type="numpy",
                )
                # NOTE(review): App.s2st currently returns a status string
                # ("Source: ..."), not the translated text, for this box.
                output_text = gr.Textbox(label="Translated text")

        # Event listeners must be registered inside the Blocks context.
        audio_source.change(
            fn=update_audio_ui,
            inputs=audio_source,
            outputs=[
                input_audio_mic,
                input_audio_file,
            ],
            queue=False,
            api_name=False,
        )
        btn.click(
            fn=app.s2st,
            inputs=[
                audio_source,
                input_audio_mic,
                input_audio_file,
            ],
            outputs=[output_audio, output_text],
            api_name="run",
        )
    demo.queue(max_size=50).launch()


if __name__ == '__main__':
    main()