from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
import tempfile, io, base64, math, os, threading
from fastapi.middleware.cors import CORSMiddleware
from contextlib import contextmanager
import soundfile as sf
from math import gcd
from scipy.signal import resample_poly


@contextmanager
def mrt_overrides(mrt, **kwargs):
    """Temporarily set attributes on MRT if they exist; restore them after."""
    old = {}
    try:
        for k, v in kwargs.items():
            if hasattr(mrt, k):
                old[k] = getattr(mrt, k)
                setattr(mrt, k, v)
        yield
    finally:
        for k, v in old.items():
            setattr(mrt, k, v)


# loudness utils
try:
    import pyloudnorm as pyln
    _HAS_LOUDNORM = True
except Exception:
    _HAS_LOUDNORM = False


def _measure_lufs(wav: au.Waveform) -> float:
    # pyloudnorm expects float32/float64, shape (n,) or (n, ch)
    meter = pyln.Meter(wav.sample_rate)  # defaults to BS.1770-4
    return float(meter.integrated_loudness(wav.samples))


def _rms(x: np.ndarray) -> float:
    if x.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(x ** 2)))


def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0,
) -> tuple[au.Waveform, dict]:
    """Scale `target` to match the loudness of `ref`. Returns (adjusted_wave, stats)."""
    stats = {"method": method, "applied_gain_db": 0.0}

    if method == "none":
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    if method == "lufs" and _HAS_LOUDNORM:
        L_ref = _measure_lufs(ref)
        L_tgt = _measure_lufs(target)
        delta_db = L_ref - L_tgt
        gain = 10.0 ** (delta_db / 20.0)  # dB difference -> linear gain
        y = target.samples.astype(np.float32) * gain
        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
    else:
        # RMS fallback
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        if rb <= 1e-12:
            return target, stats
        gain = ra / rb
        y = target.samples.astype(np.float32) * gain
        stats.update({
            "ref_rms": ra,
            "tgt_rms_before": rb,
            "applied_gain_db": 20 * np.log10(max(gain, 1e-12)),
        })

    # simple peak "limiter" to keep headroom
    limit = 10 ** (-headroom_db / 20.0)  # e.g., 1.0 dB headroom -> -1 dBFS ceiling
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > limit:
        y *= limit / peak
        stats["post_peak_limited"] = True
    else:
        stats["post_peak_limited"] = False

    target.samples = y.astype(np.float32)
    return target, stats


# ----------------------------
# Crossfade stitch
# ----------------------------
def stitch_generated(chunks, sr, xfade_s):
    if not chunks:
        raise ValueError("no chunks")
    xfade_n = int(round(xfade_s * sr))
    if xfade_n <= 0:
        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)

    # equal-power crossfade windows: sin^2 + cos^2 = 1, so summed energy stays constant
    t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]

    first = chunks[0].samples
    if first.shape[0] < xfade_n:
        raise ValueError("chunk shorter than crossfade prefix")
    out = first[xfade_n:].copy()  # drop model pre-roll

    for i in range(1, len(chunks)):
        cur = chunks[i].samples
        if cur.shape[0] < xfade_n:
            continue
        head, tail = cur[:xfade_n], cur[xfade_n:]
        mixed = out[-xfade_n:] * eq_out + head * eq_in
        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)

    return au.Waveform(out, sr)


# ----------------------------
# Bar-aligned token context
# ----------------------------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
    frames_per_bar = int(round(frames_per_bar_f))
    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
        # bar length is not an integer number of frames; fall back to plain tiling
        reps = int(np.ceil(ctx_frames / len(tokens)))
        return np.tile(tokens, (reps, 1))[-ctx_frames:]
    reps = int(np.ceil(ctx_frames / len(tokens)))
    tiled = np.tile(tokens, (reps, 1))
    end = (len(tiled) // frames_per_bar) * frames_per_bar
    if end < ctx_frames:
        return tiled[-ctx_frames:]
    start = end - ctx_frames
    return tiled[start:end]
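
# Worked example of the bar-alignment arithmetic above (hypothetical numbers):
# at bpm=120, fps=25, beats_per_bar=4, frames_per_bar_f = 4 * (60/120) * 25 = 50.0,
# so the context window can be cut on exact bar boundaries. At bpm=98 the same
# formula gives ~61.22 frames per bar, which trips the non-integer fallback path.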

def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    n = int(round(seconds * wav.sample_rate))
    return au.Waveform(wav.samples[:n], wav.sample_rate)


def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Apply short linear fade-in/out in place to avoid clicks at the loop seam."""
    n = int(wav.sample_rate * ms / 1000.0)
    if n > 0 and wav.samples.shape[0] > 2 * n:
        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
        wav.samples[:n] *= env
        wav.samples[-n:] *= env[::-1]


# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",        # "auto"|"lufs"|"rms"|"none"
    loudness_headroom_db: float = 1.0,  # for the peak guard
):
    # Load loop & encode
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
    tokens_full = mrt.codec.encode(loop).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Context
    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=int(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # ---------- STYLE: weighted average into ONE vector ----------
    embeds = []
    weights = []

    # loop embedding, with the requested loop_weight
    loop_embed = mrt.embed_style(loop)
    embeds.append(loop_embed)
    weights.append(float(loop_weight))

    # extra text styles
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()))
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))

    # Prevent all-zero weights; normalize to sum to 1
    wsum = float(sum(weights))
    if wsum <= 0.0:
        # fallback: rely on the loop embedding alone to avoid NaNs
        weights = [1.0] + [0.0] * (len(weights) - 1)
        wsum = 1.0
    weights = [w / wsum for w in weights]

    # weighted sum -> single style vector (match dtype)
    combined_style = np.sum(
        [w * e for w, e in zip(weights, embeds)], axis=0
    ).astype(loop_embed.dtype)

    # Enough chunks to cover the requested bars exactly
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate  # ~2.0 s
    steps = int(math.ceil(total_secs / chunk_secs)) + 1  # generate extra, then trim

    # Generate
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)  # ONE style vector
        chunks.append(wav)

    # Stitch -> trim -> polish
    out = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    out = hard_trim_seconds(out, total_secs).peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Loudness-match to the *input loop* so the returned level feels consistent
    out, loud_stats = match_loudness_to_reference(
        ref=loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )
    return out, loud_stats


# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or lock to your domain(s)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
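
# The model is heavy, so it is built lazily on the first request rather than at
# import time; the double-checked lock below keeps concurrent first requests
# from constructing it twice.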
_MRT = None
_MRT_LOCK = threading.Lock()


def get_mrt():
    global _MRT
    if _MRT is None:
        with _MRT_LOCK:
            if _MRT is None:
                _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
    return _MRT


@app.post("/generate")
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: float = Form(5.0),
    temperature: float = Form(1.1),
    topk: int = Form(40),
    target_sample_rate: int | None = Form(None),
):
    # Read the uploaded loop into a temp file
    data = loop_audio.file.read()
    if not data:
        return {"error": "Empty file"}
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    # Parse styles + weights (skip empty entries so "a,,b" doesn't crash float())
    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
    weights = [float(x) for x in style_weights.split(",") if x.strip()] if style_weights else None

    mrt = get_mrt()  # warm once, in this worker thread

    # Temporarily override MRT inference knobs for this request
    with mrt_overrides(mrt, guidance_weight=guidance_weight, temperature=temperature, topk=topk):
        wav, loud_stats = generate_loop_continuation_with_mrt(
            mrt,
            input_wav_path=tmp_path,
            bpm=bpm,
            extra_styles=extra_styles,
            style_weights=weights,
            bars=bars,
            beats_per_bar=beats_per_bar,
            loop_weight=loop_weight,
            loudness_mode=loudness_mode,
            loudness_headroom_db=loudness_headroom_db,
        )

    # 1) Figure out the desired output sample rate
    inp_info = sf.info(tmp_path)
    input_sr = int(inp_info.samplerate)
    target_sr = int(target_sample_rate or input_sr)

    # Done with the temp file; remove it so uploads don't accumulate on disk
    try:
        os.remove(tmp_path)
    except OSError:
        pass

    # 2) Resample the model output to target_sr if needed
    # wav.samples: shape (num_samples, num_channels), float32 in [-1, 1]
    cur_sr = int(mrt.sample_rate)
    x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]  # ensure (S, C)
    if cur_sr != target_sr:
        g = gcd(cur_sr, target_sr)
        up, down = target_sr // g, cur_sr // g
        y = np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])
    else:
        y = x

    # 3) Snap to the exact frame count for loop-perfect length
    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
    if y.shape[0] < expected_len:
        pad = np.zeros((expected_len - y.shape[0], y.shape[1]), dtype=y.dtype)
        y = np.vstack([y, pad])
    elif y.shape[0] > expected_len:
        y = y[:expected_len, :]

    total_samples = int(y.shape[0])
    loop_duration_seconds = total_samples / float(target_sr)

    # 4) Encode as float WAV @ target_sr, base64 for the JSON response
    buf = io.BytesIO()
    sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
    buf.seek(0)
    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")

    # 5) Authoritative metadata describing exactly what was returned
    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": extra_styles,
        "style_weights": weights,
        "loop_weight": loop_weight,
        "loudness": loud_stats,
        "sample_rate": int(target_sr),
        "channels": int(y.shape[1]),
        "crossfade_seconds": mrt.config.crossfade_length,
        "total_samples": total_samples,
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": loop_duration_seconds,
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }

    return {"audio_base64": audio_b64, "metadata": metadata}
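
# Example request against a local instance (hypothetical host/port; the -F field
# names match the Form(...) parameters of /generate above):
#
#   curl -X POST http://localhost:8000/generate \
#     -F "loop_audio=@my_loop.wav" \
#     -F "bpm=120" -F "bars=8" \
#     -F "styles=acid house,deep techno" -F "style_weights=1.0,0.5"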
@app.get("/health")
def health():
    return {"ok": True}
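
# Minimal local entry point: a sketch that assumes `uvicorn` is installed
# (host/port here are arbitrary choices, not part of the original service config).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)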