thecollabagepatch commited on
Commit
6d0aea5
·
1 Parent(s): 241e975

use tail end of longer contexts

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -141,20 +141,19 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
141
  wav.samples[:n] *= env
142
  wav.samples[-n:] *= env[::-1]
143
 
144
- def take_bar_aligned_tail(wav: au.Waveform,
145
- bpm: float,
146
- beats_per_bar: int,
147
- ctx_seconds: float) -> au.Waveform:
148
  """
149
  Return the LAST N bars whose duration is as close as possible to ctx_seconds,
150
  anchored to the end of `wav`, and bar-aligned.
151
  """
152
- spb = (60.0 / bpm) * beats_per_bar # seconds per bar
 
153
  bars_needed = max(1, int(round(ctx_seconds / spb)))
154
- tail_seconds = bars_needed * spb # exact multiple of bars
 
 
155
  n = int(round(tail_seconds * wav.sample_rate))
156
  if n >= wav.samples.shape[0]:
157
- # Input shorter than desired tail: keep whole thing (your existing behavior will tile)
158
  return wav
159
  return au.Waveform(wav.samples[-n:], wav.sample_rate)
160
 
@@ -187,6 +186,8 @@ def generate_loop_continuation_with_mrt(
187
  beats_per_bar=beats_per_bar,
188
  ctx_seconds=ctx_seconds
189
  )
 
 
190
 
191
  # Encode ONLY the tail (so we condition on recent audio)
192
  tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
 
141
  wav.samples[:n] *= env
142
  wav.samples[-n:] *= env[::-1]
143
 
144
+ def take_bar_aligned_tail(wav, bpm, beats_per_bar, ctx_seconds, max_bars=None):
 
 
 
145
  """
146
  Return the LAST N bars whose duration is as close as possible to ctx_seconds,
147
  anchored to the end of `wav`, and bar-aligned.
148
  """
149
+ spb = (60.0 / bpm) * beats_per_bar
150
+ print(f"[MRT] bpm={bpm}, spb={spb:.4f}, ctx_frames={mrt.config.context_length_frames}, fps={mrt.codec.frame_rate}")
151
  bars_needed = max(1, int(round(ctx_seconds / spb)))
152
+ if max_bars is not None:
153
+ bars_needed = min(bars_needed, max_bars)
154
+ tail_seconds = bars_needed * spb
155
  n = int(round(tail_seconds * wav.sample_rate))
156
  if n >= wav.samples.shape[0]:
 
157
  return wav
158
  return au.Waveform(wav.samples[-n:], wav.sample_rate)
159
 
 
186
  beats_per_bar=beats_per_bar,
187
  ctx_seconds=ctx_seconds
188
  )
189
+ print(f"[MRT] context tail: {ctx_seconds:.2f}s ≈ {loop_for_context.samples.shape[0]/loop_for_context.sample_rate:.2f}s, "
190
+ f"sr={loop_for_context.sample_rate}")
191
 
192
  # Encode ONLY the tail (so we condition on recent audio)
193
  tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)