File size: 3,316 Bytes
55bf388
 
 
 
4defacc
 
ed95978
55bf388
37c7fb5
 
55bf388
 
 
4defacc
 
 
 
 
a202342
f5a907f
a202342
4defacc
ed95978
 
 
 
 
 
 
 
 
 
 
37c7fb5
ed95978
37c7fb5
ed95978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4defacc
ed95978
 
4defacc
ed95978
 
 
 
 
 
 
 
55bf388
37c7fb5
55bf388
 
 
ed95978
 
55bf388
37c7fb5
55bf388
37c7fb5
 
 
 
 
 
 
 
55bf388
 
 
37c7fb5
 
 
 
 
 
55bf388
 
ed95978
37c7fb5
 
55bf388
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import gradio as gr
import torchaudio
from typing import Tuple, Optional
import soundfile as sf
from s2st_inference import s2st_inference
from utils import download_model

DESCRIPTION = r"**Speech-to-Speech Translation from Spanish to English**"

# --- Audio handling ---------------------------------------------------------
# Rate (Hz) that input audio is resampled to before translation.
SAMPLE_RATE = 16000
# Inputs longer than this many seconds are truncated (with a UI warning).
MAX_INPUT_LENGTH = 60  # seconds

# --- Model locations ---------------------------------------------------------
# Tags handed to `download_model`, and the local directories they are
# downloaded into.
S2UT_TAG = "espnet/jiyang_tang_cvss-c_es-en_discrete_unit"
S2UT_DIR = "model"
VOCODER_TAG = "espnet/cvss-c_en_wavegan_hubert_vocoder"
VOCODER_DIR = "vocoder"

# --- Inference options (forwarded to s2st_inference) -------------------------
NGPU = 0  # passed as `ngpu`
BEAM_SIZE = 1  # passed as `beam_size`; presumably 1 means greedy decoding
BEAM_SIZE = 1


class App:
    """Backend for the Gradio demo: Spanish speech in, English speech out.

    On construction, the speech-to-unit translation model (``S2UT_TAG``) and
    the vocoder (``VOCODER_TAG``) are fetched with ``download_model`` into
    ``S2UT_DIR`` and ``VOCODER_DIR`` respectively.
    """

    def __init__(self):
        # Ensure the target directories exist before downloading into them.
        os.makedirs(S2UT_DIR, exist_ok=True)
        os.makedirs(VOCODER_DIR, exist_ok=True)

        # Local directories holding the downloaded model / vocoder files.
        self.s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
        self.vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)

    def s2st(
            self,
            input_audio: Optional[str],
    ):
        """Translate one utterance and return the path of the synthesized audio.

        Args:
            input_audio: Path to the input audio file, as produced by a
                ``gr.Audio(type="filepath")`` input. Gradio passes ``None``
                when the user has not provided any audio.

        Returns:
            Path to ``output.wav`` (16-bit PCM), overwritten on every call.

        Raises:
            gr.Error: If no input audio was provided.
        """
        if input_audio is None:
            # Show a readable message in the UI instead of a torchaudio traceback.
            raise gr.Error("Please upload or record Spanish speech before translating.")

        orig_wav, orig_sr = torchaudio.load(input_audio)

        # Enforce the length limit at the original rate, before resampling, so an
        # overly long upload is not resampled in full only to be cut afterwards.
        max_orig_length = int(MAX_INPUT_LENGTH * orig_sr)
        if orig_wav.shape[1] > max_orig_length:
            orig_wav = orig_wav[:, :max_orig_length]
            gr.Warning(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")

        wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)
        wav = wav[0]  # mono: keep only the first channel

        # Resolve paths before changing directory so they stay valid afterwards,
        # even if download_model returned a relative path.
        model_dir = os.path.abspath(self.s2ut_path)
        vocoder_dir = os.path.abspath(self.vocoder_path)
        exp_dir = os.path.join(
            model_dir,
            'exp',
            's2st_train_s2st_discrete_unit_raw_fbank_es_en',
        )

        # The model loads correctly only with cwd set to its directory. os.chdir
        # is process-wide, so restore it in `finally`: a failed inference must not
        # leave the whole app running from inside the model directory.
        cwd = os.getcwd()
        os.chdir(model_dir)
        try:
            out_wav = s2st_inference(
                wav,
                train_config=os.path.join(exp_dir, 'config.yaml'),
                model_file=os.path.join(exp_dir, '500epoch.pth'),
                vocoder_file=os.path.join(vocoder_dir, 'checkpoint-450000steps.pkl'),
                vocoder_config=os.path.join(vocoder_dir, 'config.yml'),
                ngpu=NGPU,
                beam_size=BEAM_SIZE,
            )
        finally:
            os.chdir(cwd)

        # Save result (relative to the restored working directory).
        # NOTE(review): 16000 is presumably the vocoder's output rate, distinct
        # from the input SAMPLE_RATE; confirm against the vocoder config.yml.
        output_path = 'output.wav'
        sf.write(
            output_path,
            out_wav,
            16000,
            "PCM_16",
        )

        return output_path


def main():
    """Build the Gradio UI around :class:`App` and serve it."""
    # Models are downloaded here, before the UI is built.
    app = App()

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Group():
            input_audio = gr.Audio(
                label="Input speech",
                type="filepath",  # App.s2st reads the file via torchaudio.load
                sources=["upload", "microphone"],
                format='wav',
                streaming=False,
                visible=True,
            )

            btn = gr.Button("Translate")

            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="filepath",  # App.s2st returns the path of a WAV file
            )

        btn.click(
            fn=app.s2st,
            inputs=[input_audio],
            outputs=[output_audio],
            api_name="run",
        )

    # Launch once the Blocks context is closed (the standard Gradio pattern).
    # The queue holds at most 50 pending requests.
    demo.queue(max_size=50).launch()


if __name__ == '__main__':
    main()