Spaces:
Paused
Paused
Commit
·
c4aed03
1
Parent(s):
09491d9
drop first bar
Browse files
app.py
CHANGED
@@ -168,94 +168,84 @@ def generate_loop_continuation_with_mrt(
|
|
168 |
style_weights=None,
|
169 |
bars: int = 8,
|
170 |
beats_per_bar: int = 4,
|
171 |
-
loop_weight: float = 1.0,
|
172 |
-
loudness_mode: str = "auto",
|
173 |
-
loudness_headroom_db: float = 1.0,
|
|
|
174 |
):
|
175 |
-
# Load
|
176 |
loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
|
177 |
|
178 |
-
#
|
179 |
-
codec_fps
|
180 |
-
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
|
|
181 |
|
182 |
-
# ✅ NEW: take bar-aligned TAIL for context, if input is long enough
|
183 |
-
loop_for_context = take_bar_aligned_tail(
|
184 |
-
wav=loop,
|
185 |
-
bpm=bpm,
|
186 |
-
beats_per_bar=beats_per_bar,
|
187 |
-
ctx_seconds=ctx_seconds
|
188 |
-
)
|
189 |
-
print(f"[MRT] context tail: {ctx_seconds:.2f}s ≈ {loop_for_context.samples.shape[0]/loop_for_context.sample_rate:.2f}s, "
|
190 |
-
f"sr={loop_for_context.sample_rate}")
|
191 |
-
|
192 |
-
# Encode ONLY the tail (so we condition on recent audio)
|
193 |
tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
|
194 |
tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
|
195 |
|
196 |
-
#
|
197 |
context_tokens = make_bar_aligned_context(
|
198 |
-
tokens,
|
199 |
-
|
200 |
-
fps=int(mrt.codec.frame_rate),
|
201 |
-
ctx_frames=mrt.config.context_length_frames,
|
202 |
-
beats_per_bar=beats_per_bar,
|
203 |
)
|
204 |
state = mrt.init_state()
|
205 |
state.context_tokens = context_tokens
|
206 |
|
207 |
-
#
|
208 |
-
|
209 |
-
embeds = []
|
210 |
-
weights = []
|
211 |
-
|
212 |
-
# loop embedding
|
213 |
-
loop_embed = mrt.embed_style(loop)
|
214 |
-
embeds.append(loop_embed)
|
215 |
-
weights.append(float(loop_weight)) # <--- use requested loop weight
|
216 |
-
|
217 |
-
# extra styles
|
218 |
if extra_styles:
|
219 |
for i, s in enumerate(extra_styles):
|
220 |
if s.strip():
|
221 |
embeds.append(mrt.embed_style(s.strip()))
|
222 |
w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
|
223 |
weights.append(float(w))
|
224 |
-
|
225 |
-
# Prevent all-zero weights; normalize
|
226 |
-
wsum = float(sum(weights))
|
227 |
-
if wsum <= 0.0:
|
228 |
-
# fallback: rely on loop to avoid NaNs
|
229 |
-
weights = [1.0] + [0.0] * (len(weights) - 1)
|
230 |
-
wsum = 1.0
|
231 |
-
|
232 |
weights = [w / wsum for w in weights]
|
233 |
-
|
234 |
-
# weighted sum -> single style vector (match dtype)
|
235 |
combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
|
236 |
|
237 |
-
#
|
238 |
seconds_per_bar = beats_per_bar * (60.0 / bpm)
|
239 |
-
total_secs
|
|
|
|
|
|
|
|
|
|
|
240 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
241 |
-
steps = int(math.ceil(
|
242 |
|
243 |
# Generate
|
244 |
chunks = []
|
245 |
for _ in range(steps):
|
246 |
-
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
247 |
chunks.append(wav)
|
248 |
|
249 |
-
# Stitch
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
apply_micro_fades(out, 5)
|
253 |
-
|
|
|
254 |
out, loud_stats = match_loudness_to_reference(
|
255 |
ref=loop, target=out,
|
256 |
-
method=loudness_mode,
|
257 |
-
headroom_db=loudness_headroom_db,
|
258 |
)
|
|
|
259 |
return out, loud_stats
|
260 |
|
261 |
# ----------------------------
|
@@ -296,7 +286,8 @@ def generate(
|
|
296 |
guidance_weight: float = Form(5.0),
|
297 |
temperature: float = Form(1.1),
|
298 |
topk: int = Form(40),
|
299 |
-
target_sample_rate: int | None = Form(None),
|
|
|
300 |
):
|
301 |
# Read file
|
302 |
data = loop_audio.file.read()
|
@@ -327,6 +318,7 @@ def generate(
|
|
327 |
loop_weight=loop_weight,
|
328 |
loudness_mode=loudness_mode,
|
329 |
loudness_headroom_db=loudness_headroom_db,
|
|
|
330 |
)
|
331 |
|
332 |
# 1) Figure out the desired SR
|
|
|
168 |
style_weights=None,
|
169 |
bars: int = 8,
|
170 |
beats_per_bar: int = 4,
|
171 |
+
loop_weight: float = 1.0,
|
172 |
+
loudness_mode: str = "auto",
|
173 |
+
loudness_headroom_db: float = 1.0,
|
174 |
+
intro_bars_to_drop: int = 0, # <— NEW
|
175 |
):
|
176 |
+
# Load the input loop, resample it to the model sample rate, and convert to stereo
|
177 |
loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
|
178 |
|
179 |
+
# Use only the tail of the loop, sized to the model's context window (context frames / codec fps), as context
|
180 |
+
codec_fps = float(mrt.codec.frame_rate)
|
181 |
+
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
182 |
+
loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
|
185 |
tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
|
186 |
|
187 |
+
# Build the bar-aligned context token window from the encoded tail
|
188 |
context_tokens = make_bar_aligned_context(
|
189 |
+
tokens, bpm=bpm, fps=int(mrt.codec.frame_rate),
|
190 |
+
ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
|
|
|
|
|
|
|
191 |
)
|
192 |
state = mrt.init_state()
|
193 |
state.context_tokens = context_tokens
|
194 |
|
195 |
+
# STYLE embed: computed from loop_for_context (the context tail) for a stronger “recent” bias; pass `loop` instead to embed the full input
|
196 |
+
loop_embed = mrt.embed_style(loop_for_context)
|
197 |
+
embeds, weights = [loop_embed], [float(loop_weight)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
if extra_styles:
|
199 |
for i, s in enumerate(extra_styles):
|
200 |
if s.strip():
|
201 |
embeds.append(mrt.embed_style(s.strip()))
|
202 |
w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
|
203 |
weights.append(float(w))
|
204 |
+
wsum = float(sum(weights)) or 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
weights = [w / wsum for w in weights]
|
|
|
|
|
206 |
combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
|
207 |
|
208 |
+
# --- Length math ---
|
209 |
seconds_per_bar = beats_per_bar * (60.0 / bpm)
|
210 |
+
total_secs = bars * seconds_per_bar
|
211 |
+
drop_bars = max(0, int(intro_bars_to_drop))
|
212 |
+
drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
|
213 |
+
gen_total_secs = total_secs + drop_secs # generate extra
|
214 |
+
|
215 |
+
# Chunk scheduling to cover gen_total_secs
|
216 |
chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
|
217 |
+
steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
|
218 |
|
219 |
# Generate
|
220 |
chunks = []
|
221 |
for _ in range(steps):
|
222 |
+
wav, state = mrt.generate_chunk(state=state, style=combined_style)
|
223 |
chunks.append(wav)
|
224 |
|
225 |
+
# Stitch continuous audio
|
226 |
+
stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
|
227 |
+
|
228 |
+
# Trim to generated length (bars + dropped bars)
|
229 |
+
stitched = hard_trim_seconds(stitched, gen_total_secs)
|
230 |
+
|
231 |
+
# 👉 Drop the intro bars
|
232 |
+
if drop_secs > 0:
|
233 |
+
n_drop = int(round(drop_secs * stitched.sample_rate))
|
234 |
+
stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
|
235 |
+
|
236 |
+
# Final exact-length trim to requested bars
|
237 |
+
out = hard_trim_seconds(stitched, total_secs)
|
238 |
+
|
239 |
+
# Final polish AFTER drop
|
240 |
+
out = out.peak_normalize(0.95)
|
241 |
apply_micro_fades(out, 5)
|
242 |
+
|
243 |
+
# Loudness match to input (after drop) so bar 1 sits right
|
244 |
out, loud_stats = match_loudness_to_reference(
|
245 |
ref=loop, target=out,
|
246 |
+
method=loudness_mode, headroom_db=loudness_headroom_db
|
|
|
247 |
)
|
248 |
+
|
249 |
return out, loud_stats
|
250 |
|
251 |
# ----------------------------
|
|
|
286 |
guidance_weight: float = Form(5.0),
|
287 |
temperature: float = Form(1.1),
|
288 |
topk: int = Form(40),
|
289 |
+
target_sample_rate: int | None = Form(None),
|
290 |
+
intro_bars_to_drop: int = Form(0), # <— NEW
|
291 |
):
|
292 |
# Read file
|
293 |
data = loop_audio.file.read()
|
|
|
318 |
loop_weight=loop_weight,
|
319 |
loudness_mode=loudness_mode,
|
320 |
loudness_headroom_db=loudness_headroom_db,
|
321 |
+
intro_bars_to_drop=intro_bars_to_drop, # <— pass through
|
322 |
)
|
323 |
|
324 |
# 1) Figure out the desired SR
|