# Hugging Face Spaces page header (Space status when captured: Paused)
import gradio as gr | |
import spaces | |
import huggingface_hub | |
import numpy as np | |
import pandas as pd | |
import os | |
import shutil | |
import torch | |
from audiocraft.data.audio import audio_write | |
import audiocraft.models | |
# Fetch the MusiConGen checkpoint files from the Hugging Face Hub.
# Both files are stored under ./ckpt/musicongen, the directory the model is
# later loaded from.
for _ckpt_filename in ('compression_state_dict.bin', 'state_dict.bin'):
    huggingface_hub.hf_hub_download(
        repo_id='Cyan0731/MusiConGen',
        filename=_ckpt_filename,
        local_dir='./ckpt/musicongen',
    )
def print_directory_contents(path):
    """Print the directory tree rooted at ``path`` as an indented listing.

    Every directory is printed as ``name/`` and its files are listed one
    level deeper, using four spaces per nesting level. Entries are listed
    in sorted order so the output is stable between runs.

    Args:
        path: Directory to walk. A trailing path separator is tolerated.
            Nothing is printed if the directory cannot be listed.
    """
    # Normalise once so a trailing separator (e.g. "out/") neither blanks the
    # root's printed name nor skews the depth computed below. An empty path
    # is left alone so it still yields no output instead of walking the cwd.
    base = os.path.normpath(path) if path else path
    for root, dirs, files in os.walk(base):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        # Depth is taken from the path relative to the base; stripping the
        # base with str.replace would also remove later repeats of it.
        rel = os.path.relpath(root, base)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for name in sorted(files):
            print(f"{subindent}{name}")
def check_outputs_folder(folder_path):
    """Empty ``folder_path`` while keeping the folder itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. A failure on one entry is reported on stdout and the
    remaining entries are still processed. When ``folder_path`` is not an
    existing directory, a message is printed and nothing is deleted.
    """
    is_existing_dir = os.path.exists(folder_path) and os.path.isdir(folder_path)
    if not is_existing_dir:
        print(f'The folder {folder_path} does not exist.')
        return

    for entry_name in os.listdir(folder_path):
        target = os.path.join(folder_path, entry_name)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                # Plain file or symlink: remove just the directory entry.
                os.unlink(target)
            elif os.path.isdir(target):
                # Real directory: remove it together with its contents.
                shutil.rmtree(target)
        except Exception as err:
            # Best effort: report and carry on with the next entry.
            print(f'Failed to delete {target}. Reason: {err}')
def check_for_wav_in_outputs(outputs_folder='./output_samples/example_1'):
    """Return the path of a ``.wav`` file in the outputs folder, if any.

    Args:
        outputs_folder: Directory to search (non-recursively). Defaults to
            ``./output_samples/example_1``.

    Returns:
        ``outputs_folder`` joined with the alphabetically first file name
        ending in ``.wav``, or ``None`` when the folder is missing, is not a
        directory, or contains no such file.
    """
    # isdir also covers the "does not exist" case and avoids os.listdir
    # raising when the path points at a regular file.
    if not os.path.isdir(outputs_folder):
        return None

    # os.listdir returns entries in arbitrary order; sort so the chosen file
    # is deterministic when several .wav files are present.
    wav_files = sorted(
        name for name in os.listdir(outputs_folder) if name.endswith('.wav')
    )
    if not wav_files:
        return None
    return os.path.join(outputs_folder, wav_files[0])
def _safe_output_stem(chords, description, max_bytes=250):
    """Build a single-component file stem from the chords and description."""
    stem = f"{chords}|{description}".replace(" ", "_")
    # Path separators or NUL in user text would redirect the write into
    # another directory (or make it fail), so neutralise them.
    stem = stem.translate(str.maketrans({'/': '_', '\\': '_', '\0': '_'}))
    # Keep the name under the common 255-byte file-name limit, leaving room
    # for the extension added by the writer; drop any split trailing char.
    return stem.encode('utf-8')[:max_bytes].decode('utf-8', 'ignore')


def infer(prompt_in, chords, duration, bpms):
    """Generate one audio sample from a text prompt, chords and tempo.

    Args:
        prompt_in: Text description of the music to generate.
        chords: Chord progression string (e.g. ``'B:min D F#:min E'``).
        duration: Requested sample duration, forwarded to the model's
            generation parameters.
        bpms: Tempo value forwarded to the model.

    Returns:
        The value returned by ``check_for_wav_in_outputs()`` after the sample
        has been written: the path of a ``.wav`` file in the output folder,
        or ``None`` if none is found.
    """
    # set hparams
    output_dir = 'example_1'  # sub-folder of ./output_samples used for this run
    output_folder = os.path.join('./output_samples', output_dir)
    num_samples = 1
    bs = 1  # batch size

    # Empty the output folder first so the file picked up at the end is the
    # one written by this call.
    check_outputs_folder(output_folder)

    # load your model
    # NOTE(review): the checkpoint is loaded again on every call; consider
    # caching the model if load time becomes a problem.
    musicgen = audiocraft.models.MusicGen.get_pretrained('./ckpt/musicongen')
    musicgen.set_generation_params(duration=duration, extend_stride=duration//2, top_k=250)

    # One conditioning entry per sample, all lists kept the same length.
    descriptions = [prompt_in] * num_samples
    chord_prompts = [chords] * num_samples
    bpms = [bpms] * num_samples
    meters = [4] * num_samples  # fixed meter value of 4 for every sample

    wav = []
    for i in range(num_samples // bs):
        print(f"starting {i} batch...")
        batch = slice(i * bs, (i + 1) * bs)
        temp = musicgen.generate_with_chords_and_beats(
            descriptions[batch],
            chord_prompts[batch],
            bpms[batch],
            meters[batch],
        )
        wav.extend(temp.cpu())

    # save and display generated audio
    for idx, one_wav in enumerate(wav):
        # The stem is derived from user text, so it is sanitised to stay a
        # single, length-bounded file name inside the output folder.
        stem = _safe_output_stem(chord_prompts[idx], descriptions[idx])
        sav_path = os.path.join(output_folder, stem)
        audio_write(sav_path, one_wav.cpu(), musicgen.sample_rate, strategy='loudness', loudness_compressor=True)

    # Print the outputs directory contents
    print_directory_contents('./output_samples')
    wav_file_path = check_for_wav_in_outputs()
    print(wav_file_path)
    return wav_file_path
# Custom CSS handed to gr.Blocks below: limits the width of the element with
# id "col-container" and centres it, and enlarges the text of buttons inside
# the element with id "chords-examples".
css="""
#col-container{
    max-width: 800px;
    margin: 0 auto;
}
#chords-examples button{
    font-size: 20px;
}
"""
# Gradio UI: prompt / chords / duration / BPM inputs, a submit button wired to
# infer(), an audio output, and two sets of clickable examples.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file whose indentation had been lost — confirm the
# intended placement of the button, audio output and example rows.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# MusiConGen")
        gr.Markdown("## Rhythm and Chord Control for Transformer-Based Text-to-Music Generation")
        # Badge linking to the project page.
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://musicongen.github.io/musicongen_demo/'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
        </div>
        """)
        with gr.Column():
            with gr.Group():
                prompt_in = gr.Textbox(label="Music description", value="A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums.")
                with gr.Row():
                    # NOTE(review): Gradio documents `scale` as an integer;
                    # 1.75 may trigger a warning — confirm against the
                    # installed Gradio version.
                    chords = gr.Textbox(label="Chords progression", value='B:min D F#:min E', scale=1.75)
                    duration = gr.Slider(label="Sample duration", minimum=4, maximum=30, step=1, value=30)
                    bpms = gr.Slider(label="BPMs", minimum=50, maximum=220, step=1, value=120)
                submit_btn = gr.Button("Submit")
            # Pre-populated with a bundled sample until a generation replaces it.
            wav_out = gr.Audio(label="Wav Result", value="./MusiConGen_default_sample_space_example.wav")
        with gr.Row():
            # Clicking an example fills the "Music description" textbox.
            gr.Examples(
                label = "Audio description examples",
                examples = [
                    ["A laid-back blues shuffle with a relaxed tempo, warm guitar tones, and a comfortable groove, perfect for a slow dance or a night in. Instruments: electric guitar, bass, drums."],
                    ["A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums."],
                    ["A classic rock n' roll tune with catchy guitar riffs, driving drums, and a pulsating bass line, reminiscent of the golden era of rock. Instruments: electric guitar, bass, drums."],
                    ["A high-energy funk tune with slap bass, rhythmic guitar riffs, and a tight horn section, guaranteed to get you grooving. Instruments: bass, guitar, trumpet, saxophone, drums."],
                    ["A heavy metal onslaught with double kick drum madness, aggressive guitar riffs, and an unrelenting bass, embodying the spirit of metal. Instruments: electric guitar, bass guitar, drums."]
                ],
                inputs = [prompt_in]
            )
            # Clicking an example fills the "Chords progression" textbox.
            gr.Examples(
                label = "Chords progression examples",
                elem_id = "chords-examples",
                examples = ['C G A:min F',
                            'A:min F C G',
                            'C F G F',
                            'C A:min F G',
                            'D:min G C A:min',
                            'D:min7 G:7 C:maj7 C:maj7',
                            'F G E:min A:min',
                            'B:min D F#:min E',
                            'F G E A:min',
                            'C Bb F C',
                            'A:min C D F',
                            'B:min F#:min E:min B:min',
                            'B:min7 E:9 A:maj7 A:maj7 C#:7 F#:min7',
                            'F:min G:min Ab Bb',
                            'A:min G F D:min'
                            ],
                inputs = [chords],
                examples_per_page = 16
            )
    # Run infer() with the four inputs and show its result in the audio player.
    submit_btn.click(
        fn = infer,
        inputs = [prompt_in, chords, duration, bpms],
        outputs = [wav_out]
    )
demo.launch(show_api=False, show_error=True)