Spaces:
Running
on
L40S
Running
on
L40S
Commit
·
6d0aea5
1
Parent(s):
241e975
use tail end of longer contexts
Browse files
app.py
CHANGED
@@ -141,20 +141,19 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
141 |
wav.samples[:n] *= env
|
142 |
wav.samples[-n:] *= env[::-1]
|
143 |
|
144 |
-
def take_bar_aligned_tail(wav
|
145 |
-
bpm: float,
|
146 |
-
beats_per_bar: int,
|
147 |
-
ctx_seconds: float) -> au.Waveform:
|
148 |
"""
|
149 |
Return the LAST N bars whose duration is as close as possible to ctx_seconds,
|
150 |
anchored to the end of `wav`, and bar-aligned.
|
151 |
"""
|
152 |
-
spb = (60.0 / bpm) * beats_per_bar
|
|
|
153 |
bars_needed = max(1, int(round(ctx_seconds / spb)))
|
154 |
-
|
|
|
|
|
155 |
n = int(round(tail_seconds * wav.sample_rate))
|
156 |
if n >= wav.samples.shape[0]:
|
157 |
-
# Input shorter than desired tail: keep whole thing (your existing behavior will tile)
|
158 |
return wav
|
159 |
return au.Waveform(wav.samples[-n:], wav.sample_rate)
|
160 |
|
@@ -187,6 +186,8 @@ def generate_loop_continuation_with_mrt(
|
|
187 |
beats_per_bar=beats_per_bar,
|
188 |
ctx_seconds=ctx_seconds
|
189 |
)
|
|
|
|
|
190 |
|
191 |
# Encode ONLY the tail (so we condition on recent audio)
|
192 |
tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
|
|
|
141 |
wav.samples[:n] *= env
|
142 |
wav.samples[-n:] *= env[::-1]
|
143 |
|
144 |
+
def take_bar_aligned_tail(wav, bpm, beats_per_bar, ctx_seconds, max_bars=None):
    """
    Return the LAST N bars whose duration is as close as possible to ctx_seconds,
    anchored to the end of `wav`, and bar-aligned.

    Args:
        wav: Waveform-like object exposing `samples` (sliced along its first
            axis) and `sample_rate`.
        bpm: Tempo in beats per minute; must be greater than zero.
        beats_per_bar: Number of beats in one bar; must be greater than zero.
        ctx_seconds: Target context duration in seconds.
        max_bars: Optional upper bound on the number of bars to keep. Values
            below 1 are treated as 1 so that at least one bar is returned.

    Returns:
        `wav` itself when the requested tail is at least as long as the input,
        otherwise a new `au.Waveform` holding only the last `n` samples.

    Raises:
        ValueError: If `bpm` or `beats_per_bar` is not positive.
    """
    # A non-positive tempo or meter would give a zero/negative bar length and,
    # further down, a negative sample count that slices the wrong end.
    if bpm <= 0 or beats_per_bar <= 0:
        raise ValueError(
            f"bpm and beats_per_bar must be positive, got bpm={bpm!r}, "
            f"beats_per_bar={beats_per_bar!r}"
        )

    # Seconds per bar.
    spb = (60.0 / bpm) * beats_per_bar

    # NOTE(review): `mrt` is not a parameter of this function; this debug line
    # relies on a name from an enclosing/module scope -- confirm it is defined
    # wherever this helper is called.
    print(f"[MRT] bpm={bpm}, spb={spb:.4f}, ctx_frames={mrt.config.context_length_frames}, fps={mrt.codec.frame_rate}")

    # Whole number of bars closest to the requested context, never fewer than one.
    bars_needed = max(1, int(round(ctx_seconds / spb)))
    if max_bars is not None:
        # Keep the floor of one bar: with zero bars `samples[-0:]` would return
        # the entire waveform, and a negative count would drop the head instead
        # of keeping the tail.
        bars_needed = max(1, min(bars_needed, max_bars))

    tail_seconds = bars_needed * spb
    n = int(round(tail_seconds * wav.sample_rate))
    if n >= wav.samples.shape[0]:
        # Input is no longer than the desired tail: hand it back unchanged.
        return wav
    return au.Waveform(wav.samples[-n:], wav.sample_rate)
|
159 |
|
|
|
186 |
beats_per_bar=beats_per_bar,
|
187 |
ctx_seconds=ctx_seconds
|
188 |
)
|
189 |
+
print(f"[MRT] context tail: {ctx_seconds:.2f}s ≈ {loop_for_context.samples.shape[0]/loop_for_context.sample_rate:.2f}s, "
|
190 |
+
f"sr={loop_for_context.sample_rate}")
|
191 |
|
192 |
# Encode ONLY the tail (so we condition on recent audio)
|
193 |
tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
|