|
import gradio as gr |
|
import torchaudio |
|
from audiocraft.models import AudioGen |
|
from audiocraft.data.audio import audio_write |
|
|
|
# Build the AudioGen model once at module level; `infer` reuses this single
# shared instance for every request instead of reloading the checkpoint.
_CHECKPOINT_NAME = 'facebook/audiogen-medium'
model = AudioGen.get_pretrained(_CHECKPOINT_NAME)
|
|
|
def infer(prompt, duration):
    """Generate audio for a text prompt and return the path of the saved clip.

    Args:
        prompt: Text description; passed to the model as the only entry of
            the description batch.
        duration: Forwarded to ``model.set_generation_params`` as
            ``duration`` (presumably seconds -- confirm against AudioCraft).

    Returns:
        The literal string ``"0.wav"``. This is presumably the file that
        ``audio_write`` produces for the stem ``"0"``, i.e. the first clip
        of the generated batch.
    """
    model.set_generation_params(duration=duration)
    generated = model.generate([prompt])

    # Every clip in the batch is saved, using its batch index as file stem.
    for clip_no, clip in enumerate(generated):
        audio_write(
            f'{clip_no}',
            clip.cpu(),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )

    return "0.wav"
|
|
|
# Custom stylesheet handed to gr.Blocks: centres the element with id
# "col-container" and caps its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 640px;
}
"""
|
|
|
# Static header markup rendered at the top of the page.
_HEADER_HTML = """
            <h2 style="text-align: center;">
                AudioGen: Textually-guided audio generation
            </h2>
            <p style="text-align: center;">
            </p>
        """

# Page layout: one column holding the header, the prompt box, a row with the
# duration slider and the submit button, and the audio output component.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(_HEADER_HTML)

        prompt_in = gr.Textbox(label="audio prompt")
        with gr.Row():
            # With step=5 between 5 and 10 the slider offers two values.
            duration = gr.Slider(
                label="Duration",
                minimum=5,
                maximum=10,
                step=5,
                value=5,
            )
            submit_btn = gr.Button("Submit")
        audio_o = gr.Audio(label="AudioGen result")

        # Clicking the button calls `infer` with the prompt and duration and
        # feeds its return value to the audio component.
        submit_btn.click(fn=infer, inputs=[prompt_in, duration], outputs=[audio_o])

# Turn on the request queue, then start the app with debug enabled.
demo.queue().launch(debug=True)