Update Overlap Action in Melody
- app.py +5 -3
- audiocraft/models/musicgen.py +55 -2
- audiocraft/utils/extend.py +36 -23
app.py
CHANGED
@@ -100,6 +100,8 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, ...
             temperature=temperature,
             cfg_coef=cfg_coef,
             duration=segment_duration,
+            two_step_cfg=False,
+            rep_penalty=0.5
         )

     if melody:
@@ -201,7 +203,7 @@ def ui(**kwargs):
         include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
         with gr.Row():
             title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
-            settings_font = gr.Text(label="Settings Font", value="arial.ttf", interactive=True)
+            settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
             settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
         with gr.Row():
             model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
@@ -212,8 +214,8 @@ def ui(**kwargs):
         with gr.Row():
             topk = gr.Number(label="Top-k", value=250, interactive=True)
             topp = gr.Number(label="Top-p", value=0, interactive=True)
-            temperature = gr.Number(label="Randomness Temperature", value=
-            cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.
+            temperature = gr.Number(label="Randomness Temperature", value=0.75, precision=None, interactive=True)
+            cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.5, precision=None, interactive=True)
         with gr.Row():
             seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
             gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
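
The two new keyword arguments in the predict() hunk above line up with the set_generation_params() signature change in audiocraft/models/musicgen.py below. A minimal sketch of the per-segment configuration call, assuming the enclosing call (whose opening line falls outside the hunk) is MODEL.set_generation_params — only the keyword arguments shown in the hunk are confirmed by the diff:

    # Hypothetical reconstruction of the per-segment configuration in predict().
    MODEL.set_generation_params(
        top_k=topk,                  # UI default 250
        top_p=topp,                  # UI default 0 (disabled)
        temperature=temperature,     # UI default 0.75 after this commit
        cfg_coef=cfg_coef,           # UI default 5.5 after this commit
        duration=segment_duration,   # capped at 30 s by MusicGen's assert
        two_step_cfg=False,          # new in this commit
        rep_penalty=0.5,             # new in this commit; documented as not implemented
    )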
audiocraft/models/musicgen.py
CHANGED
@@ -97,7 +97,7 @@ class MusicGen:
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False):
+                              two_step_cfg: bool = False, rep_penalty: float = None):
         """Set the generation parameters for MusicGen.

         Args:
@@ -110,6 +110,7 @@ class MusicGen:
             two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                 instead of batching together the two. This has some impact on how things
                 are padded but seems to have little impact in practice.
+            rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
         """
         assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
         self.generation_params = {
@@ -119,7 +120,7 @@ class MusicGen:
             'top_k': top_k,
             'top_p': top_p,
             'cfg_coef': cfg_coef,
-            'two_step_cfg': two_step_cfg,
+            'two_step_cfg': two_step_cfg,
         }

     def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
@@ -177,6 +178,58 @@ class MusicGen:
         assert prompt_tokens is None
         return self._generate_tokens(attributes, prompt_tokens, progress)

+    def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                          sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Generate samples conditioned on text and melody and audio prompts.
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+                melody conditioning. Should have shape [B, C, T] with B matching the description length,
+                C=1 or 2. It can be [C, T] if there is a single description. It can also be
+                a list of [C, T] tensors.
+            sample_rate: (int): Sample rate of the melody waveforms.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+        """
+        if isinstance(melody_wavs, torch.Tensor):
+            if melody_wavs.dim() == 2:
+                melody_wavs = melody_wavs[None]
+            if melody_wavs.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            melody_wavs = list(melody_wavs)
+        else:
+            for melody in melody_wavs:
+                if melody is not None:
+                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+        melody_wavs = [
+            convert_audio(wav, sample_rate, self.sample_rate, self.audio_channels)
+            if wav is not None else None
+            for wav in melody_wavs]
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                        melody_wavs=melody_wavs)
+
+        if prompt is not None:
+            if prompt.dim() == 2:
+                prompt = prompt[None]
+            if prompt.dim() != 3:
+                raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+            prompt = convert_audio(prompt, sample_rate, self.sample_rate, self.audio_channels)
+            if descriptions is None:
+                descriptions = [None] * len(prompt)
+
+        if prompt is not None:
+            attributes_gen, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+
+        #attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=prompt,
+        #                                                                melody_wavs=melody_wavs)
+        if prompt is not None:
+            assert prompt_tokens is not None
+        else:
+            assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
     def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                               descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                               progress: bool = False) -> torch.Tensor:
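
For reference, a hedged usage sketch of the new generate_with_all() method; the checkpoint name, audio file path, and parameter values are illustrative, not taken from this commit:

    import torchaudio
    from audiocraft.models import MusicGen

    model = MusicGen.get_pretrained('melody')
    model.set_generation_params(duration=30, temperature=0.75, cfg_coef=5.5)

    melody, sr = torchaudio.load('./assets/example.mp3')  # [C, T] waveform

    # First segment: no audio prompt. Later segments would pass the overlap
    # tail of the previous output (see extend.py below) as `prompt`.
    wav = model.generate_with_all(
        descriptions=['happy rock'],
        melody_wavs=melody,   # [C, T] is promoted to [1, C, T] internally
        sample_rate=sr,       # resampled to the model rate via convert_audio
        progress=True,
        prompt=None,
    )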
audiocraft/utils/extend.py
CHANGED
@@ -22,12 +22,15 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
     start_sample = 0

     while total_samples >= segment_samples:
+        # Collect the segment
+        # the end sample is the start sample plus the segment samples,
+        # the start sample, after 0, is minus the overlap samples to account for the overlap
         end_sample = start_sample + segment_samples
         segment = audio_data[start_sample:end_sample]
         segments.append((sr, segment))

         start_sample += segment_samples - overlap_samples
-        total_samples -= segment_samples
+        total_samples -= segment_samples

     # Collect the final segment
     if total_samples > 0:
@@ -38,17 +41,16 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):

 def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
     # generate audio segments
-    melody_segments = separate_audio_segments(melody, segment_duration,
+    melody_segments = separate_audio_segments(melody, segment_duration, 0)

     # Create a list to store the melody tensors for each segment
     melodys = []
     output_segments = []
+    last_chunk = []
+    text += ", seed=" + str(seed)

     # Calculate the total number of segments
     total_segments = max(math.ceil(duration / segment_duration),1)
-    # account for overlap
-    duration = duration + (max((total_segments - 1),0) * overlap)
-    total_segments = max(math.ceil(duration / segment_duration),1)
     #calc excess duration
     excess_duration = segment_duration - (total_segments * segment_duration - duration)
     print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration}")
@@ -76,11 +78,15 @@ def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
     torch.manual_seed(seed)
     for idx, verse in enumerate(melodys):
         print(f"Generating New Melody Segment {idx + 1}: {text}\r")
-
+        if output_segments:
+            # If this isn't the first segment, use the last chunk of the previous segment as the input
+            last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
+        output = MODEL.generate_with_all(
             descriptions=[text],
             melody_wavs=verse,
-
-            progress=True
+            sample_rate=sr,
+            progress=True,
+            prompt=last_chunk if len(last_chunk) > 0 else None,
         )

         # Append the generated output to the list of segments
@@ -151,24 +157,31 @@ def load_font(font_name, font_size=16):
     Example:
         font = load_font("Arial.ttf", font_size=20)
     """
-    font = ImageFont.truetype(font_name, font_size)
-    except (FileNotFoundError, OSError):
+    font = None
+    if not "http" in font_name:
         try:
             font = ImageFont.truetype(font_name, font_size)
+        except (FileNotFoundError, OSError):
+            print("Font not found. Trying to download from local assets folder...\n")
+    if font is None:
         try:
-            print("Font not found.
+            font = ImageFont.truetype("assets/" + font_name, font_size)
+        except (FileNotFoundError, OSError):
+            print("Font not found. Trying to download from URL...\n")
+
+    if font is None:
+        try:
+            req = requests.get(font_name)
+            font = ImageFont.truetype(BytesIO(req.content), font_size)
+        except (FileNotFoundError, OSError):
+            print(f"Font found: {font_name} Using Hugging Face download font\n")
+
+    if font is None:
+        try:
+            font = ImageFont.truetype(hf_hub_download("assets", font_name), encoding="UTF-8")
+        except (FileNotFoundError, OSError):
+            font = ImageFont.load_default()
+            print(f"Font not found: {font_name} Using default font\n")
     return font
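
To make the overlap hand-off concrete, here is a standalone sketch of the arithmetic the updated generate_music_segments() relies on; shapes and values are examples, not model output:

    import torch

    sample_rate, overlap = 32000, 1                      # MusicGen outputs 32 kHz audio
    prev_output = torch.zeros(1, 1, 30 * sample_rate)    # [B, C, T] from the model

    # The last `overlap` seconds of the previous segment become the next prompt,
    # mirroring: output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
    last_chunk = prev_output[:, :, -overlap * sample_rate:]
    assert last_chunk.shape == (1, 1, overlap * sample_rate)

    # Melody segmentation is now done with overlap=0, so consecutive chunks are
    # stitched together purely through this audio prompt rather than by
    # overlapping the melody windows themselves.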