from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
import tempfile, io, base64, math, threading
from fastapi.middleware.cors import CORSMiddleware
from contextlib import contextmanager
import soundfile as sf
from math import gcd
from scipy.signal import resample_poly


@contextmanager
def mrt_overrides(mrt, **kwargs):
    """Temporarily set attributes on MRT if they exist; restore after."""
    old = {}
    try:
        for k, v in kwargs.items():
            if hasattr(mrt, k):
                old[k] = getattr(mrt, k)
                setattr(mrt, k, v)
        yield
    finally:
        for k, v in old.items():
            setattr(mrt, k, v)


# loudness utils
try:
    import pyloudnorm as pyln
    _HAS_LOUDNORM = True
except Exception:
    _HAS_LOUDNORM = False


def _measure_lufs(wav: au.Waveform) -> float:
    # pyloudnorm expects float32/float64, shape (n,) or (n, ch)
    meter = pyln.Meter(wav.sample_rate)  # defaults to BS.1770-4
    return float(meter.integrated_loudness(wav.samples))


def _rms(x: np.ndarray) -> float:
    if x.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(x**2)))


def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0,
) -> tuple[au.Waveform, dict]:
    """Scale `target` to match `ref` loudness. Returns (adjusted_wave, stats)."""
    stats = {"method": method, "applied_gain_db": 0.0}

    if method == "none":
        return target, stats

    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    if method == "lufs" and _HAS_LOUDNORM:
        L_ref = _measure_lufs(ref)
        L_tgt = _measure_lufs(target)
        delta_db = L_ref - L_tgt
        gain = 10.0 ** (delta_db / 20.0)
        y = target.samples.astype(np.float32) * gain
        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt,
                      "applied_gain_db": delta_db})
    else:
        # RMS fallback
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        if rb <= 1e-12:
            return target, stats
        gain = ra / rb
        y = target.samples.astype(np.float32) * gain
        stats.update({"ref_rms": ra, "tgt_rms_before": rb,
                      "applied_gain_db": 20 * np.log10(max(gain, 1e-12))})

    # simple peak "limiter" to keep headroom
    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > limit:
        y *= (limit / peak)
        stats["post_peak_limited"] = True
    else:
        stats["post_peak_limited"] = False

    target.samples = y.astype(np.float32)
    return target, stats


# ----------------------------
# Crossfade stitch
# ----------------------------
def stitch_generated(chunks, sr, xfade_s):
    """Concatenate chunks with an equal-power crossfade of `xfade_s` seconds."""
    if not chunks:
        raise ValueError("no chunks")
    xfade_n = int(round(xfade_s * sr))
    if xfade_n <= 0:
        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)

    t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]

    first = chunks[0].samples
    if first.shape[0] < xfade_n:
        raise ValueError("chunk shorter than crossfade prefix")
    out = first[xfade_n:].copy()  # drop model pre-roll

    for i in range(1, len(chunks)):
        cur = chunks[i].samples
        if cur.shape[0] < xfade_n:
            continue
        head, tail = cur[:xfade_n], cur[xfade_n:]
        mixed = out[-xfade_n:] * eq_out + head * eq_in
        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
    return au.Waveform(out, sr)


# ----------------------------
# Bar-aligned token context
# ----------------------------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
    frames_per_bar = int(round(frames_per_bar_f))
    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
        # Non-integer frames per bar: fall back to a simple tail-tiled context.
        reps = int(np.ceil(ctx_frames / len(tokens)))
        return np.tile(tokens, (reps, 1))[-ctx_frames:]

    reps = int(np.ceil(ctx_frames / len(tokens)))
    tiled = np.tile(tokens, (reps, 1))
    end = (len(tiled) // frames_per_bar) * frames_per_bar
    if end < ctx_frames:
        return tiled[-ctx_frames:]
    start = end - ctx_frames
    return tiled[start:end]
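
# A tiny self-check sketch for the helpers above. All numbers here are
# illustrative assumptions (fake noise chunks, 48 kHz, 120 BPM), not MagentaRT
# defaults; it only exercises stitch_generated and make_bar_aligned_context.
def _demo_stitch_and_context() -> None:
    sr = 48000
    # three fake 2 s stereo chunks, stitched with a 40 ms equal-power crossfade
    chunks = [au.Waveform(np.random.randn(2 * sr, 2).astype(np.float32), sr)
              for _ in range(3)]
    out = stitch_generated(chunks, sr, xfade_s=0.04)
    xfade_n = int(round(0.04 * sr))
    # one dropped pre-roll prefix plus two overlaps: 3*2*sr - 3*xfade_n samples
    assert out.samples.shape[0] == 3 * 2 * sr - 3 * xfade_n
    # 100 token frames tiled into a 250-frame bar-aligned window at 120 BPM
    tokens = np.zeros((100, 16), dtype=np.int32)
    ctx = make_bar_aligned_context(tokens, bpm=120.0, fps=25, ctx_frames=250)
    assert ctx.shape == (250, 16)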
def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    n = int(round(seconds * wav.sample_rate))
    return au.Waveform(wav.samples[:n], wav.sample_rate)


def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """In-place linear fade-in/out of `ms` milliseconds to avoid edge clicks."""
    n = int(wav.sample_rate * ms / 1000.0)
    if n > 0 and wav.samples.shape[0] > 2 * n:
        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
        wav.samples[:n] *= env
        wav.samples[-n:] *= env[::-1]


def take_bar_aligned_tail(wav, bpm, beats_per_bar, ctx_seconds, max_bars=None):
    """
    Return the LAST N bars whose duration is as close as possible to
    ctx_seconds, anchored to the end of `wav`, and bar-aligned.
    """
    spb = (60.0 / bpm) * beats_per_bar
    bars_needed = max(1, int(round(ctx_seconds / spb)))
    if max_bars is not None:
        bars_needed = min(bars_needed, max_bars)
    tail_seconds = bars_needed * spb
    n = int(round(tail_seconds * wav.sample_rate))
    if n >= wav.samples.shape[0]:
        return wav
    return au.Waveform(wav.samples[-n:], wav.sample_rate)


# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,  # NEW
):
    # Load & prep
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Use a bar-aligned tail of the loop as context
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Bar-aligned token window
    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=int(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )

    state = mrt.init_state()
    state.context_tokens = context_tokens

    # Style embed from the bar-aligned context tail (biases the style toward
    # the most recent bars of the input loop)
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed], [float(loop_weight)]
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()))
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))

    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)

    # --- Length math ---
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar  # clamp to <= bars
    gen_total_secs = total_secs + drop_secs  # generate extra, then trim

    # Chunk scheduling to cover gen_total_secs
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate  # ~2.0
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1  # pad, then trim
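
    # Worked example of the scheduling above (illustrative numbers): at 120 BPM
    # with 4 beats/bar, seconds_per_bar = 2.0, so bars=8 gives total_secs = 16.0.
    # With intro_bars_to_drop=1, gen_total_secs = 18.0; if chunk_secs ~ 2.0,
    # steps = ceil(18.0 / 2.0) + 1 = 10 chunks are generated, then stitched and
    # trimmed back to exactly 16.0 s after the intro bar is dropped below.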
    # Generate
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)

    # Stitch continuous audio
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()

    # Trim to generated length (bars + dropped bars)
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # 👉 Drop the intro bars
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Final exact-length trim to the requested bars
    out = hard_trim_seconds(stitched, total_secs)

    # Final polish AFTER the drop
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Loudness-match to the input (after the drop) so bar 1 sits right
    out, loud_stats = match_loudness_to_reference(
        ref=loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )

    return out, loud_stats


# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or lock to your domain(s)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_MRT = None
_MRT_LOCK = threading.Lock()


def get_mrt():
    """Lazily build the MagentaRT system once (double-checked locking)."""
    global _MRT
    if _MRT is None:
        with _MRT_LOCK:
            if _MRT is None:
                _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
    return _MRT


@app.post("/generate")
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: float = Form(5.0),
    temperature: float = Form(1.1),
    topk: int = Form(40),
    target_sample_rate: int | None = Form(None),
    intro_bars_to_drop: int = Form(0),  # NEW
):
    # Read the uploaded file
    data = loop_audio.file.read()
    if not data:
        return {"error": "Empty file"}
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    # Parse styles + weights
    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
    weights = [float(x) for x in style_weights.split(",")] if style_weights else None

    mrt = get_mrt()  # warm once, in this worker thread

    # Temporarily override MRT inference knobs for this request
    with mrt_overrides(mrt,
                       guidance_weight=guidance_weight,
                       temperature=temperature,
                       topk=topk):
        wav, loud_stats = generate_loop_continuation_with_mrt(
            mrt,
            input_wav_path=tmp_path,
            bpm=bpm,
            extra_styles=extra_styles,
            style_weights=weights,
            bars=bars,
            beats_per_bar=beats_per_bar,
            loop_weight=loop_weight,
            loudness_mode=loudness_mode,
            loudness_headroom_db=loudness_headroom_db,
            intro_bars_to_drop=intro_bars_to_drop,  # pass through
        )

    # 1) Figure out the desired output sample rate
    inp_info = sf.info(tmp_path)
    input_sr = int(inp_info.samplerate)
    target_sr = int(target_sample_rate or input_sr)

    # 2) Convert the model output to target_sr if needed.
    #    wav.samples: float32 in [-1, 1], shape (num_samples, num_channels)
    cur_sr = int(mrt.sample_rate)
    x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]  # ensure (S, C)
    if cur_sr != target_sr:
        g = gcd(cur_sr, target_sr)
        up, down = target_sr // g, cur_sr // g
        y = np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])
    else:
        y = x
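    # Worked example of the rational resample above (illustrative): going from
    # cur_sr=48000 to target_sr=44100 gives g = gcd(48000, 44100) = 300, so
    # up=147 and down=160; resample_poly then runs one polyphase filter per
    # channel at the 147/160 ratio instead of separate up/down passes.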
    # 3) Snap to the exact frame count for loop-perfect length
    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
    if y.shape[0] < expected_len:
        pad = np.zeros((expected_len - y.shape[0], y.shape[1]), dtype=y.dtype)
        y = np.vstack([y, pad])
    elif y.shape[0] > expected_len:
        y = y[:expected_len, :]

    total_samples = int(y.shape[0])
    loop_duration_seconds = total_samples / float(target_sr)

    # 4) Write y into an in-memory WAV @ target_sr
    buf = io.BytesIO()
    sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
    buf.seek(0)
    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")

    # 5) Authoritative metadata for the client
    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": extra_styles,
        "style_weights": weights,
        "loop_weight": loop_weight,
        "loudness": loud_stats,
        "sample_rate": int(target_sr),
        "channels": int(y.shape[1]),
        "crossfade_seconds": mrt.config.crossfade_length,
        "total_samples": total_samples,
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": loop_duration_seconds,
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }

    return {"audio_base64": audio_b64, "metadata": metadata}


@app.get("/health")
def health():
    return {"ok": True}
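
# ----------------------------
# Minimal client sketch (assumptions: the server is reachable at
# http://localhost:8000 and the third-party `requests` package is installed;
# neither is required by the server module itself).
# ----------------------------
def fetch_continuation_example(wav_path: str, bpm: float, out_path: str = "out.wav") -> dict:
    """POST a loop to /generate and write the decoded WAV. Illustrative only."""
    import requests  # local import: keeps it out of the server's dependencies

    with open(wav_path, "rb") as f:
        resp = requests.post(
            "http://localhost:8000/generate",
            files={"loop_audio": ("loop.wav", f, "audio/wav")},
            data={"bpm": str(bpm), "bars": "8", "styles": "acid house"},
            timeout=600,
        )
    resp.raise_for_status()
    payload = resp.json()
    with open(out_path, "wb") as out_f:
        out_f.write(base64.b64decode(payload["audio_base64"]))
    return payload["metadata"]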