thecollabagepatch commited on
Commit
c4aed03
·
1 Parent(s): 09491d9

drop first bar

Browse files
Files changed (1) hide show
  1. app.py +48 -56
app.py CHANGED
@@ -168,94 +168,84 @@ def generate_loop_continuation_with_mrt(
168
  style_weights=None,
169
  bars: int = 8,
170
  beats_per_bar: int = 4,
171
- loop_weight: float = 1.0, # NEW
172
- loudness_mode: str = "auto", # "auto"|"lufs"|"rms"|"none"
173
- loudness_headroom_db: float = 1.0, # for the peak guard
 
174
  ):
175
- # Load loop & put into model SR/channels
176
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
177
 
178
- # Compute the model's desired context seconds (e.g., 250 frames / 25 fps = 10s)
179
- codec_fps = float(mrt.codec.frame_rate)
180
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps # typically 10.0s
 
181
 
182
- # ✅ NEW: take bar-aligned TAIL for context, if input is long enough
183
- loop_for_context = take_bar_aligned_tail(
184
- wav=loop,
185
- bpm=bpm,
186
- beats_per_bar=beats_per_bar,
187
- ctx_seconds=ctx_seconds
188
- )
189
- print(f"[MRT] context tail: {ctx_seconds:.2f}s ≈ {loop_for_context.samples.shape[0]/loop_for_context.sample_rate:.2f}s, "
190
- f"sr={loop_for_context.sample_rate}")
191
-
192
- # Encode ONLY the tail (so we condition on recent audio)
193
  tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
194
  tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
195
 
196
- # Context
197
  context_tokens = make_bar_aligned_context(
198
- tokens,
199
- bpm=bpm,
200
- fps=int(mrt.codec.frame_rate),
201
- ctx_frames=mrt.config.context_length_frames,
202
- beats_per_bar=beats_per_bar,
203
  )
204
  state = mrt.init_state()
205
  state.context_tokens = context_tokens
206
 
207
- # ---------- STYLE: weighted avg into ONE vector ----------
208
- # Base embed from loop with adjustable loop_weight
209
- embeds = []
210
- weights = []
211
-
212
- # loop embedding
213
- loop_embed = mrt.embed_style(loop)
214
- embeds.append(loop_embed)
215
- weights.append(float(loop_weight)) # <--- use requested loop weight
216
-
217
- # extra styles
218
  if extra_styles:
219
  for i, s in enumerate(extra_styles):
220
  if s.strip():
221
  embeds.append(mrt.embed_style(s.strip()))
222
  w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
223
  weights.append(float(w))
224
-
225
- # Prevent all-zero weights; normalize
226
- wsum = float(sum(weights))
227
- if wsum <= 0.0:
228
- # fallback: rely on loop to avoid NaNs
229
- weights = [1.0] + [0.0] * (len(weights) - 1)
230
- wsum = 1.0
231
-
232
  weights = [w / wsum for w in weights]
233
-
234
- # weighted sum -> single style vector (match dtype)
235
  combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
236
 
237
- # Chunks to cover exact bars
238
  seconds_per_bar = beats_per_bar * (60.0 / bpm)
239
- total_secs = bars * seconds_per_bar
 
 
 
 
 
240
  chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
241
- steps = int(math.ceil(total_secs / chunk_secs)) + 1 # pad then trim
242
 
243
  # Generate
244
  chunks = []
245
  for _ in range(steps):
246
- wav, state = mrt.generate_chunk(state=state, style=combined_style) # ONE style vector
247
  chunks.append(wav)
248
 
249
- # Stitch -> trim -> polish
250
- out = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
251
- out = hard_trim_seconds(out, total_secs).peak_normalize(0.95)
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  apply_micro_fades(out, 5)
253
- # Loudness match to the *input loop* so the return level feels consistent
 
254
  out, loud_stats = match_loudness_to_reference(
255
  ref=loop, target=out,
256
- method=loudness_mode,
257
- headroom_db=loudness_headroom_db,
258
  )
 
259
  return out, loud_stats
260
 
261
  # ----------------------------
@@ -296,7 +286,8 @@ def generate(
296
  guidance_weight: float = Form(5.0),
297
  temperature: float = Form(1.1),
298
  topk: int = Form(40),
299
- target_sample_rate: int | None = Form(None), # <-- add this
 
300
  ):
301
  # Read file
302
  data = loop_audio.file.read()
@@ -327,6 +318,7 @@ def generate(
327
  loop_weight=loop_weight,
328
  loudness_mode=loudness_mode,
329
  loudness_headroom_db=loudness_headroom_db,
 
330
  )
331
 
332
  # 1) Figure out the desired SR
 
168
  style_weights=None,
169
  bars: int = 8,
170
  beats_per_bar: int = 4,
171
+ loop_weight: float = 1.0,
172
+ loudness_mode: str = "auto",
173
+ loudness_headroom_db: float = 1.0,
174
+ intro_bars_to_drop: int = 0, # <— NEW
175
  ):
176
+ # Load & prep (unchanged)
177
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
178
 
179
+ # Use tail for context (your recent change)
180
+ codec_fps = float(mrt.codec.frame_rate)
181
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
182
+ loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
183
 
 
 
 
 
 
 
 
 
 
 
 
184
  tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
185
  tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
186
 
187
+ # Bar-aligned token window (unchanged)
188
  context_tokens = make_bar_aligned_context(
189
+ tokens, bpm=bpm, fps=int(mrt.codec.frame_rate),
190
+ ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
 
 
 
191
  )
192
  state = mrt.init_state()
193
  state.context_tokens = context_tokens
194
 
195
+ # STYLE embed (optional: switch to loop_for_context if you want stronger “recent” bias)
196
+ loop_embed = mrt.embed_style(loop_for_context)
197
+ embeds, weights = [loop_embed], [float(loop_weight)]
 
 
 
 
 
 
 
 
198
  if extra_styles:
199
  for i, s in enumerate(extra_styles):
200
  if s.strip():
201
  embeds.append(mrt.embed_style(s.strip()))
202
  w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
203
  weights.append(float(w))
204
+ wsum = float(sum(weights)) or 1.0
 
 
 
 
 
 
 
205
  weights = [w / wsum for w in weights]
 
 
206
  combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
207
 
208
+ # --- Length math ---
209
  seconds_per_bar = beats_per_bar * (60.0 / bpm)
210
+ total_secs = bars * seconds_per_bar
211
+ drop_bars = max(0, int(intro_bars_to_drop))
212
+ drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
213
+ gen_total_secs = total_secs + drop_secs # generate extra
214
+
215
+ # Chunk scheduling to cover gen_total_secs
216
  chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
217
+ steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
218
 
219
  # Generate
220
  chunks = []
221
  for _ in range(steps):
222
+ wav, state = mrt.generate_chunk(state=state, style=combined_style)
223
  chunks.append(wav)
224
 
225
+ # Stitch continuous audio
226
+ stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
227
+
228
+ # Trim to generated length (bars + dropped bars)
229
+ stitched = hard_trim_seconds(stitched, gen_total_secs)
230
+
231
+ # 👉 Drop the intro bars
232
+ if drop_secs > 0:
233
+ n_drop = int(round(drop_secs * stitched.sample_rate))
234
+ stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
235
+
236
+ # Final exact-length trim to requested bars
237
+ out = hard_trim_seconds(stitched, total_secs)
238
+
239
+ # Final polish AFTER drop
240
+ out = out.peak_normalize(0.95)
241
  apply_micro_fades(out, 5)
242
+
243
+ # Loudness match to input (after drop) so bar 1 sits right
244
  out, loud_stats = match_loudness_to_reference(
245
  ref=loop, target=out,
246
+ method=loudness_mode, headroom_db=loudness_headroom_db
 
247
  )
248
+
249
  return out, loud_stats
250
 
251
  # ----------------------------
 
286
  guidance_weight: float = Form(5.0),
287
  temperature: float = Form(1.1),
288
  topk: int = Form(40),
289
+ target_sample_rate: int | None = Form(None),
290
+ intro_bars_to_drop: int = Form(0), # <— NEW
291
  ):
292
  # Read file
293
  data = loop_audio.file.read()
 
318
  loop_weight=loop_weight,
319
  loudness_mode=loudness_mode,
320
  loudness_headroom_db=loudness_headroom_db,
321
+ intro_bars_to_drop=intro_bars_to_drop, # <— pass through
322
  )
323
 
324
  # 1) Figure out the desired SR