from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
import tempfile, io, base64, math, os, threading
from fastapi.middleware.cors import CORSMiddleware

# ----------------------------
# Loudness utils
# ----------------------------
try:
    import pyloudnorm as pyln
    _HAS_LOUDNORM = True
except Exception:
    _HAS_LOUDNORM = False


def _measure_lufs(wav: au.Waveform) -> float:
    # pyloudnorm expects float32/float64, shape (n,) or (n, ch)
    meter = pyln.Meter(wav.sample_rate)  # defaults to BS.1770-4
    return float(meter.integrated_loudness(wav.samples))


def _rms(x: np.ndarray) -> float:
    if x.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(x ** 2)))


def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0,
) -> tuple[au.Waveform, dict]:
    """Scale `target` to match `ref` loudness. Returns (adjusted_wave, stats)."""
    stats = {"method": method, "applied_gain_db": 0.0}

    if method == "none":
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    if method == "lufs" and _HAS_LOUDNORM:
        L_ref = _measure_lufs(ref)
        L_tgt = _measure_lufs(target)
        delta_db = L_ref - L_tgt
        gain = 10.0 ** (delta_db / 20.0)
        y = target.samples.astype(np.float32) * gain
        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
    else:
        # RMS fallback (also used when pyloudnorm is unavailable)
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        if rb <= 1e-12:
            return target, stats
        gain = ra / rb
        y = target.samples.astype(np.float32) * gain
        stats.update({
            "ref_rms": ra,
            "tgt_rms_before": rb,
            "applied_gain_db": 20 * np.log10(max(gain, 1e-12)),
        })

    # Simple peak "limiter" to keep headroom: a single static gain reduction,
    # not a time-varying limiter.
    limit = 10 ** (-headroom_db / 20.0)  # e.g., 1.0 dB headroom -> -1 dBFS ceiling
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > limit:
        y *= limit / peak
        stats["post_peak_limited"] = True
    else:
        stats["post_peak_limited"] = False

    target.samples = y.astype(np.float32)
    return target, stats


# ----------------------------
# Crossfade stitch
# ----------------------------
def stitch_generated(chunks, sr, xfade_s):
    """Concatenate generated chunks with an equal-power crossfade of xfade_s seconds."""
    if not chunks:
        raise ValueError("no chunks")
    xfade_n = int(round(xfade_s * sr))
    if xfade_n <= 0:
        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)

    # Equal-power curves: sin^2 + cos^2 = 1, so the summed *power* stays constant.
    t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]

    first = chunks[0].samples
    if first.shape[0] < xfade_n:
        raise ValueError("chunk shorter than crossfade prefix")
    out = first[xfade_n:].copy()  # drop model pre-roll

    for i in range(1, len(chunks)):
        cur = chunks[i].samples
        if cur.shape[0] < xfade_n:
            continue
        head, tail = cur[:xfade_n], cur[xfade_n:]
        mixed = out[-xfade_n:] * eq_out + head * eq_in
        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
    return au.Waveform(out, sr)
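
# Illustrative only (not part of the API): a minimal smoke test for
# stitch_generated, assuming the au.Waveform(samples, sample_rate) constructor
# used above. The stitched length is sum(len(chunk)) - len(chunks) * xfade_n:
# the first chunk loses its pre-roll prefix and every later chunk loses the
# overlap region. Call by hand to sanity-check a crossfade config change.
def _demo_stitch_lengths(sr: int = 48000, xfade_s: float = 0.04) -> None:
    xfade_n = int(round(xfade_s * sr))
    chunks = [
        au.Waveform(0.1 * np.random.randn(sr, 2).astype(np.float32), sr)
        for _ in range(3)
    ]
    stitched = stitch_generated(chunks, sr, xfade_s)
    assert stitched.samples.shape[0] == 3 * sr - 3 * xfade_n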
# ----------------------------
# Bar-aligned token context
# ----------------------------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
    frames_per_bar = int(round(frames_per_bar_f))
    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
        # Tempo doesn't land on an integer number of codec frames per bar;
        # fall back to an unaligned tile of the most recent ctx_frames.
        reps = int(np.ceil(ctx_frames / len(tokens)))
        return np.tile(tokens, (reps, 1))[-ctx_frames:]
    reps = int(np.ceil(ctx_frames / len(tokens)))
    tiled = np.tile(tokens, (reps, 1))
    end = (len(tiled) // frames_per_bar) * frames_per_bar
    if end < ctx_frames:
        return tiled[-ctx_frames:]
    start = end - ctx_frames
    return tiled[start:end]


def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    n = int(round(seconds * wav.sample_rate))
    return au.Waveform(wav.samples[:n], wav.sample_rate)


def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Apply short linear fade-in/fade-out in place to avoid boundary clicks."""
    n = int(wav.sample_rate * ms / 1000.0)
    if n > 0 and wav.samples.shape[0] > 2 * n:
        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
        wav.samples[:n] *= env
        wav.samples[-n:] *= env[::-1]


# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    loudness_headroom_db: float = 1.0,  # for the peak guard
):
    # Load loop & encode
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
    tokens_full = mrt.codec.encode(loop).astype(np.int32)
    tokens = tokens_full[:, : mrt.config.decoder_codec_rvq_depth]

    # Context
    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=int(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # ---------- STYLE: weighted average into ONE vector ----------
    embeds = []
    weights = []

    # Loop embedding, scaled by the requested loop_weight
    loop_embed = mrt.embed_style(loop)
    embeds.append(loop_embed)
    weights.append(float(loop_weight))

    # Extra text styles
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()))
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))

    # Prevent all-zero weights; normalize to sum to 1
    wsum = float(sum(weights))
    if wsum <= 0.0:
        # Fallback: rely on the loop alone to avoid NaNs
        weights = [1.0] + [0.0] * (len(weights) - 1)
        wsum = 1.0
    weights = [w / wsum for w in weights]

    # Weighted sum -> single style vector (match model dtype)
    combined_style = np.sum(
        [w * e for w, e in zip(weights, embeds)], axis=0
    ).astype(loop_embed.dtype)

    # Number of chunks needed to cover the requested bars exactly
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate  # ~2.0 s
    steps = int(math.ceil(total_secs / chunk_secs)) + 1  # overshoot, then trim

    # Generate
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)  # ONE style vector
        chunks.append(wav)

    # Stitch -> trim -> polish
    out = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    out = hard_trim_seconds(out, total_secs).peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Loudness-match to the *input loop* so the return level feels consistent
    out, loud_stats = match_loudness_to_reference(
        ref=loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )
    return out, loud_stats


# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or lock this down to your domain(s)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_MRT = None
_MRT_LOCK = threading.Lock()


def get_mrt():
    """Lazily construct the MagentaRT system once, guarded by a lock (double-checked)."""
    global _MRT
    if _MRT is None:
        with _MRT_LOCK:
            if _MRT is None:
                _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
    return _MRT
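
# Illustrative only: how the weighted style averaging above behaves, with toy
# vectors standing in for mrt.embed_style(...) outputs (hypothetical values,
# no model needed). With loop_weight=1.0 and one text style at weight 3.0,
# the normalized weights are [0.25, 0.75], so the text style dominates the
# combined vector.
def _demo_style_mix() -> np.ndarray:
    loop_embed = np.ones(4, dtype=np.float32)         # stand-in for the loop embedding
    text_embed = np.full(4, -1.0, dtype=np.float32)   # stand-in for a text-style embedding
    weights = [1.0, 3.0]
    wsum = float(sum(weights))
    weights = [w / wsum for w in weights]             # -> [0.25, 0.75]
    mixed = np.sum([w * e for w, e in zip(weights, [loop_embed, text_embed])], axis=0)
    return mixed  # all -0.5: pulled three-quarters of the way toward the text style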
@app.post("/generate")
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
):
    # Read the uploaded loop
    data = loop_audio.file.read()
    if not data:
        return {"error": "Empty file"}
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    # Parse styles + weights (skip empty entries so trailing commas don't crash)
    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
    weights = [float(x) for x in style_weights.split(",") if x.strip()] if style_weights else None

    mrt = get_mrt()  # warm once, in this worker thread

    try:
        wav, loud_stats = generate_loop_continuation_with_mrt(
            mrt,
            input_wav_path=tmp_path,
            bpm=bpm,
            extra_styles=extra_styles,
            style_weights=weights,
            bars=bars,
            beats_per_bar=beats_per_bar,
            loop_weight=loop_weight,
            loudness_mode=loudness_mode,
            loudness_headroom_db=loudness_headroom_db,
        )
    finally:
        os.unlink(tmp_path)  # NamedTemporaryFile(delete=False) leaves the file behind otherwise

    # Return base64 WAV + minimal metadata
    buf = io.BytesIO()
    wav.write(buf, subtype="FLOAT", format="WAV")  # format="WAV" is required for file-like objects
    buf.seek(0)
    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")

    return {
        "audio_base64": audio_b64,
        "metadata": {
            "bpm": int(round(bpm)),
            "bars": int(bars),
            "beats_per_bar": int(beats_per_bar),
            "styles": extra_styles,
            "style_weights": weights,
            "loop_weight": loop_weight,
            "loudness": loud_stats,
            "sample_rate": mrt.sample_rate,
            "channels": mrt.num_channels,
            "crossfade_seconds": mrt.config.crossfade_length,
        },
    }


@app.get("/health")
def health():
    return {"ok": True}
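
# Illustrative client, not part of the server. Assumes the app is running
# locally (e.g. `uvicorn server:app`; the module name is hypothetical) and
# that the `requests` package is installed on the client side.
def _example_client(loop_path: str = "loop.wav",
                    url: str = "http://127.0.0.1:8000/generate") -> None:
    import requests  # client-side dependency only

    with open(loop_path, "rb") as f:
        resp = requests.post(
            url,
            files={"loop_audio": ("loop.wav", f, "audio/wav")},
            data={
                "bpm": 120,
                "bars": 8,
                "styles": "acid house,breakbeat",
                "style_weights": "1.0,0.5",
                "loop_weight": 1.0,
            },
            timeout=600,  # first request may pay the model warm-up cost
        )
    resp.raise_for_status()
    body = resp.json()
    with open("continuation.wav", "wb") as f:
        f.write(base64.b64decode(body["audio_base64"]))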