Spaces:
Paused
Paused
import base64
import io
import math
import os
import tempfile
import threading
from contextlib import contextmanager
from math import gcd

import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from scipy.signal import resample_poly

from magenta_rt import system, audio as au
@contextmanager
def mrt_overrides(mrt, **kwargs):
    """Temporarily set attributes on ``mrt`` for the duration of a ``with`` block.

    Only keyword names that already exist as attributes on ``mrt`` are
    overridden; unknown names are silently ignored. Every overridden attribute
    is restored to its previous value on exit, even if the body raises.

    This mutates ``mrt`` in place, so any other code using the same object while
    the block is active sees the overridden values; callers sharing ``mrt``
    across threads must serialize access themselves.
    """
    old = {}
    try:
        for k, v in kwargs.items():
            if hasattr(mrt, k):
                # Record the old value before overwriting so `finally` can restore it.
                old[k] = getattr(mrt, k)
                setattr(mrt, k, v)
        yield
    finally:
        for k, v in old.items():
            setattr(mrt, k, v)
# Loudness utilities. pyloudnorm is optional: when it imports, LUFS-based
# loudness matching is available; otherwise callers fall back to RMS.
try:
    import pyloudnorm as pyln
except Exception:
    _HAS_LOUDNORM = False
else:
    _HAS_LOUDNORM = True
def _measure_lufs(wav: au.Waveform) -> float:
    """Return the integrated loudness of ``wav`` (LUFS) measured with pyloudnorm.

    Requires pyloudnorm to have imported successfully (``_HAS_LOUDNORM``).
    pyloudnorm expects float32/float64 samples shaped (n,) or (n, ch); its
    default Meter settings follow ITU-R BS.1770-4.
    """
    loudness_meter = pyln.Meter(wav.sample_rate)
    loudness = loudness_meter.integrated_loudness(wav.samples)
    return float(loudness)
def _rms(x: np.ndarray) -> float:
    """Return the root-mean-square level of ``x`` taken over all elements.

    The computation is done in float64 so integer PCM input (e.g. int16) cannot
    overflow when squared. An empty array yields 0.0.
    """
    x = np.asarray(x, dtype=np.float64)
    if x.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(np.square(x))))
def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0,
) -> tuple[au.Waveform, dict]:
    """Scale ``target`` so its loudness matches ``ref``.

    Args:
        ref: Reference waveform whose loudness is the goal.
        target: Waveform to adjust. It is not modified; when a gain is applied
            a new Waveform is returned.
        method: "none" returns ``target`` unchanged. "lufs" matches integrated
            loudness via pyloudnorm. "rms" matches RMS level. "auto" picks
            "lufs" when pyloudnorm is available, else "rms". Any other value
            uses RMS. "lufs" also falls back to RMS when pyloudnorm is missing,
            the clip is too short to measure, or a reading is non-finite
            (e.g. silence).
        headroom_db: After gain, the whole signal is scaled down linearly if
            its sample peak exceeds ``-headroom_db`` dBFS.

    Returns:
        ``(adjusted_wave, stats)``. ``stats`` holds "method" (as requested),
        "resolved_method" (what was actually used), "applied_gain_db", the
        method's measurements, and "post_peak_limited" when a gain was applied.
    """
    stats = {"method": method, "applied_gain_db": 0.0}
    if method == "none":
        stats["resolved_method"] = "none"
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    y = None
    if method == "lufs" and _HAS_LOUDNORM:
        try:
            L_ref = _measure_lufs(ref)
            L_tgt = _measure_lufs(target)
        except ValueError:
            # pyloudnorm rejects audio shorter than one gating block.
            L_ref = L_tgt = math.nan
        # Silent / fully gated audio measures as -inf or NaN: that would give an
        # infinite or NaN gain and non-JSON-serializable stats, so use RMS instead.
        if math.isfinite(L_ref) and math.isfinite(L_tgt):
            delta_db = L_ref - L_tgt
            gain = 10.0 ** (delta_db / 20.0)
            y = target.samples.astype(np.float32) * gain
            stats.update({
                "resolved_method": "lufs",
                "ref_lufs": L_ref,
                "tgt_lufs_before": L_tgt,
                "applied_gain_db": delta_db,
            })

    if y is None:
        # RMS path (explicit "rms", unknown method, or LUFS fallback).
        stats["resolved_method"] = "rms"
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        if rb <= 1e-12:
            # Target is (near) silent: no meaningful gain can be derived.
            return target, stats
        gain = ra / rb
        y = target.samples.astype(np.float32) * gain
        stats.update({
            "ref_rms": ra,
            "tgt_rms_before": rb,
            "applied_gain_db": float(20.0 * np.log10(max(gain, 1e-12))),
        })

    # Simple peak guard (a linear rescale, not a true limiter) to keep headroom.
    limit = 10.0 ** (-headroom_db / 20.0)  # e.g. headroom_db=1 -> -1 dBFS
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    stats["post_peak_limited"] = peak > limit
    if peak > limit:
        y *= limit / peak
    return au.Waveform(y.astype(np.float32), target.sample_rate), stats
# ----------------------------
# Crossfade stitching of generated chunks
# ----------------------------
def stitch_generated(chunks, sr, xfade_s):
    """Join generated chunks into one Waveform using an equal-power crossfade.

    The first ``round(xfade_s * sr)`` samples of each chunk are blended with
    the last samples of the audio accumulated so far (sin fade-in / cos
    fade-out). The first chunk has nothing to blend with, so its leading
    crossfade region is discarded. Chunks shorter than the crossfade region
    after the first are skipped. With a non-positive crossfade length the
    chunks are simply concatenated.

    Raises:
        ValueError: if ``chunks`` is empty or the first chunk is shorter than
            the crossfade region.
    """
    if not chunks:
        raise ValueError("no chunks")

    fade_len = int(round(xfade_s * sr))
    if fade_len <= 0:
        joined = np.concatenate([chunk.samples for chunk in chunks], axis=0)
        return au.Waveform(joined, sr)

    phase = np.linspace(0, np.pi / 2, fade_len, endpoint=False, dtype=np.float32)
    fade_in = np.sin(phase)[:, None]
    fade_out = np.cos(phase)[:, None]

    lead = chunks[0].samples
    if lead.shape[0] < fade_len:
        raise ValueError("chunk shorter than crossfade prefix")
    result = lead[fade_len:].copy()  # drop model pre-roll

    for chunk in chunks[1:]:
        samples = chunk.samples
        if samples.shape[0] < fade_len:
            continue
        blended = result[-fade_len:] * fade_out + samples[:fade_len] * fade_in
        result = np.concatenate([result[:-fade_len], blended, samples[fade_len:]], axis=0)

    return au.Waveform(result, sr)
# ----------------------------
# Bar-aligned token context
# ----------------------------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    """Tile ``tokens`` along axis 0 and return a ``ctx_frames``-row window.

    ``tokens`` is repeated just enough times to cover ``ctx_frames`` rows.
    When one bar spans a whole number of frames (within 1e-3), the window ends
    at the last multiple of the bar length in the tiled sequence (counted from
    row 0). Otherwise, or if that boundary leaves fewer than ``ctx_frames``
    rows, the last ``ctx_frames`` rows are returned.
    """
    bar_frames_exact = (beats_per_bar * 60.0 / bpm) * fps
    bar_frames = int(round(bar_frames_exact))
    bar_is_whole = abs(bar_frames - bar_frames_exact) <= 1e-3

    copies = int(np.ceil(ctx_frames / len(tokens)))
    looped = np.tile(tokens, (copies, 1))

    if bar_is_whole:
        bar_end = (len(looped) // bar_frames) * bar_frames
        if bar_end >= ctx_frames:
            return looped[bar_end - ctx_frames:bar_end]
    return looped[-ctx_frames:]
def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    """Return a new Waveform holding the first ``seconds`` of ``wav``.

    The cut point is rounded to the nearest sample. A negative duration is
    treated as zero (empty result) instead of producing a negative slice index,
    which would silently drop samples from the END of the audio. A duration
    longer than ``wav`` returns all of its samples.
    """
    n = max(0, int(round(seconds * wav.sample_rate)))
    return au.Waveform(wav.samples[:n], wav.sample_rate)
def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Apply linear fade-in and fade-out ramps of ``ms`` milliseconds, in place.

    Does nothing when the ramp length rounds down to zero samples or when the
    waveform is not longer than two ramps. The ramp array has shape (n, 1), so
    ``wav.samples`` is presumably (samples, channels) — TODO confirm.
    """
    ramp_len = int(wav.sample_rate * ms / 1000.0)
    if ramp_len <= 0 or wav.samples.shape[0] <= 2 * ramp_len:
        return
    ramp_up = np.linspace(0.0, 1.0, ramp_len, dtype=np.float32)[:, None]
    wav.samples[:ramp_len] *= ramp_up
    wav.samples[-ramp_len:] *= ramp_up[::-1]
def take_bar_aligned_tail(wav, bpm, beats_per_bar, ctx_seconds, max_bars=None):
    """Return the last N whole bars of ``wav``, anchored to its end.

    N is ``ctx_seconds`` expressed in bars (rounded to the nearest integer),
    capped at ``max_bars`` when given, and never less than one bar. If ``wav``
    is not longer than N bars it is returned unchanged.
    """
    spb = (60.0 / bpm) * beats_per_bar  # seconds per bar
    bars_needed = max(1, int(round(ctx_seconds / spb)))
    if max_bars is not None:
        # Re-apply the one-bar floor: max_bars <= 0 would give n == 0, and
        # samples[-0:] selects the WHOLE array rather than an empty tail.
        bars_needed = max(1, min(bars_needed, max_bars))
    tail_seconds = bars_needed * spb
    n = int(round(tail_seconds * wav.sample_rate))
    if n <= 0 or n >= wav.samples.shape[0]:
        return wav
    return au.Waveform(wav.samples[-n:], wav.sample_rate)
# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def _combine_style_embeddings(mrt, loop_embed, loop_weight, extra_styles, style_weights):
    """Return the weight-normalized blend of the loop embedding and text-style embeddings."""
    embeds, weights = [loop_embed], [float(loop_weight)]
    for i, s in enumerate(extra_styles or []):
        prompt = s.strip()
        if not prompt:
            continue
        embeds.append(mrt.embed_style(prompt))
        # Weights pair with styles by position in `extra_styles`; missing ones default to 1.0.
        w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
        weights.append(float(w))
    wsum = float(sum(weights)) or 1.0  # a zero sum is left unnormalized rather than dividing by 0
    weights = [w / wsum for w in weights]
    return np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """Generate ``bars`` bars of audio continuing the loop in ``input_wav_path``.

    Pipeline:
      1. Load the loop, resample to ``mrt.sample_rate`` and make it stereo.
      2. Take the loop's last whole bars (about the model's context length),
         encode them, keep the first ``decoder_codec_rvq_depth`` codec levels,
         and install a bar-aligned token window as the model context.
      3. Blend the style embedding of that tail (weight ``loop_weight``) with
         text-prompt embeddings from ``extra_styles``/``style_weights``.
      4. Generate enough chunks for ``bars`` plus ``intro_bars_to_drop`` extra
         bars, stitch them with crossfades, drop the intro bars, and trim to
         exactly ``bars`` bars.
      5. Peak-normalize to 0.95, apply 5 ms edge fades, then match loudness to
         the input loop.

    Returns:
        ``(waveform, loudness_stats)`` as produced by
        ``match_loudness_to_reference``.

    Raises:
        ValueError: if ``bpm`` or ``beats_per_bar`` is not positive.
    """
    if bpm <= 0 or beats_per_bar <= 0:
        raise ValueError("bpm and beats_per_bar must be positive")

    # Load & prep
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Condition on the loop's tail, trimmed to whole bars near the context length.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Pass the unrounded frame rate: int() would truncate a fractional codec
    # rate and skew the frames-per-bar computation.
    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=codec_fps,
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # Style: the same tail audio used for context, blended with text prompts.
    loop_embed = mrt.embed_style(loop_for_context)
    combined_style = _combine_style_embeddings(
        mrt, loop_embed, loop_weight, extra_styles, style_weights
    )

    # --- Length math ---
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar  # never drop more than `bars`
    gen_total_secs = total_secs + drop_secs  # generate the dropped bars on top

    # Enough chunks to cover gen_total_secs, plus one spare; the excess is trimmed.
    chunk_secs = (
        mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
    )
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)

    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Drop the intro bars
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Final exact-length trim to the requested bars
    out = hard_trim_seconds(stitched, total_secs)

    # Polish after the drop so bar 1 of the output gets the fade-in.
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Loudness match to the input loop (after the drop) so bar 1 sits right.
    out, loud_stats = match_loudness_to_reference(
        ref=loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )
    return out, loud_stats
# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin requests; FastAPI's CORS docs say "*"
# should not be combined with credentials. The endpoints visible here read no
# cookies or auth headers, so allow_credentials=False or an explicit origin
# list is likely safe — confirm against the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # or lock to your domain(s)
    allow_headers=["*"],
    allow_methods=["*"],
)
# Process-wide MagentaRT instance, created on first use.
_MRT = None
_MRT_LOCK = threading.Lock()


def get_mrt():
    """Return the shared MagentaRT model, constructing it on first call.

    The lock plus re-check ensures only one thread constructs the model even if
    several requests arrive before it exists; later calls skip the lock.
    """
    global _MRT
    if _MRT is not None:
        return _MRT
    with _MRT_LOCK:
        if _MRT is None:
            _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
    return _MRT
# Serializes generation requests. mrt_overrides() mutates the shared MagentaRT
# singleton, and sync endpoints run in worker threads, so without this lock two
# concurrent requests would run with (and restore) each other's
# guidance_weight/temperature/topk.
_GENERATE_LOCK = threading.Lock()


def _resample_channels(x, cur_sr, target_sr):
    """Resample a (S,) or (S, C) array from cur_sr to target_sr; always returns (S', C)."""
    if x.ndim == 1:
        x = x[:, None]
    if cur_sr == target_sr:
        return x
    g = gcd(cur_sr, target_sr)
    up, down = target_sr // g, cur_sr // g
    return np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])


def _fit_length(y, n):
    """Zero-pad or truncate a (S, C) array along axis 0 to exactly max(n, 0) rows."""
    n = max(0, n)
    if y.shape[0] < n:
        pad = np.zeros((n - y.shape[0], y.shape[1]), dtype=y.dtype)
        return np.vstack([y, pad])
    return y[:n]


# NOTE(review): no route decorator (e.g. @app.post(...)) is visible for this
# handler in this source — confirm it is registered on `app`.
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: float = Form(5.0),
    temperature: float = Form(1.1),
    topk: int = Form(40),
    target_sample_rate: int | None = Form(None),
    intro_bars_to_drop: int = Form(0),
):
    """Generate a bar-aligned continuation of an uploaded loop.

    Form fields mirror ``generate_loop_continuation_with_mrt``; ``styles`` and
    ``style_weights`` are comma-separated lists. A blank weight entry (e.g. the
    trailing one in "1.0,") means the default weight 1.0. The output is
    resampled to ``target_sample_rate`` (default: the upload's sample rate) and
    padded or trimmed to exactly ``bars`` bars.

    Returns:
        ``{"audio_base64": <32-bit float WAV, base64>, "metadata": {...}}``, or
        ``{"error": ...}`` for an empty upload or unparseable ``style_weights``.
    """
    data = loop_audio.file.read()
    if not data:
        return {"error": "Empty file"}

    # Parse styles + weights before touching the filesystem.
    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
    try:
        weights = (
            [float(x) if x.strip() else 1.0 for x in style_weights.split(",")]
            if style_weights
            else None
        )
    except ValueError:
        return {"error": "Invalid style_weights"}

    # delete=False keeps the file after the `with` closes it so it can be
    # reopened by path; it must therefore be removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        mrt = get_mrt()  # warm once, in this worker thread
        with _GENERATE_LOCK, mrt_overrides(
            mrt,
            guidance_weight=guidance_weight,
            temperature=temperature,
            topk=topk,
        ):
            wav, loud_stats = generate_loop_continuation_with_mrt(
                mrt,
                input_wav_path=tmp_path,
                bpm=bpm,
                extra_styles=extra_styles,
                style_weights=weights,
                bars=bars,
                beats_per_bar=beats_per_bar,
                loop_weight=loop_weight,
                loudness_mode=loudness_mode,
                loudness_headroom_db=loudness_headroom_db,
                intro_bars_to_drop=intro_bars_to_drop,
            )
        input_sr = int(sf.info(tmp_path).samplerate)
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass  # best-effort cleanup; never mask the real error

    # 1) Desired output rate, then 2) convert the model output to it.
    target_sr = int(target_sample_rate or input_sr)
    y = _resample_channels(wav.samples, int(mrt.sample_rate), target_sr)

    # 3) Snap to the exact frame count for a loop-perfect length.
    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
    y = _fit_length(y, expected_len)
    total_samples = int(y.shape[0])
    loop_duration_seconds = total_samples / float(target_sr)

    # 4) Encode as a 32-bit float WAV at target_sr.
    buf = io.BytesIO()
    sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
    audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    # 5) Metadata describing the returned audio.
    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": extra_styles,
        "style_weights": weights,
        "loop_weight": loop_weight,
        "loudness": loud_stats,
        "sample_rate": int(target_sr),
        "channels": int(y.shape[1]),
        "crossfade_seconds": mrt.config.crossfade_length,
        "total_samples": total_samples,
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": loop_duration_seconds,
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }
    return {"audio_base64": audio_b64, "metadata": metadata}
# NOTE(review): no route decorator is visible for this handler in this source —
# confirm it is registered on `app`.
def health():
    """Liveness check; always reports ok."""
    status = {"ok": True}
    return status