#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import sys import gc import re import json import time import mmap import math import torch import random import logging import warnings import traceback import subprocess import numpy as np import torchaudio import gradio as gr import gradio_client.utils from pydub import AudioSegment from datetime import datetime from pathlib import Path from typing import Optional, Tuple, Dict, Any, List from torch.cuda.amp import autocast from fastapi import FastAPI, HTTPException, Body from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import uvicorn import threading # ====================================================================================== # PATCHES & RUNTIME SETUP # ====================================================================================== # Gradio schema bool patch (prevents crash for boolean schemas) _original_get_type = gradio_client.utils.get_type def _patched_get_type(schema): if isinstance(schema, bool): return "boolean" return _original_get_type(schema) gradio_client.utils.get_type = _patched_get_type # Warnings warnings.filterwarnings("ignore") # Allocator for CUDA 12.x os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" # Determinism/Benchmark settings torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True # Logging LOG_DIR = "logs" os.makedirs(LOG_DIR, exist_ok=True) LOG_FILE = os.path.join(LOG_DIR, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") logging.basicConfig( level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger("ghostai-musicgen") # Device DEVICE = "cuda" if torch.cuda.is_available() else "cpu" if DEVICE != "cuda": logger.error("CUDA is required. 
Exiting.") sys.exit(1) logger.info(f"GPU: {torch.cuda.get_device_name(0)}") logger.info("Precision: fp16 model, fp32 CPU audio ops") # ====================================================================================== # SETTINGS PERSISTENCE # ====================================================================================== SETTINGS_FILE = "settings.json" DEFAULT_SETTINGS: Dict[str, Any] = { "cfg_scale": 5.8, "top_k": 250, # more creative search space "top_p": 0.95, # user requested higher probability cap "temperature": 0.90, # user requested ~0.9 "total_duration": 60, # default to 1 minute "bpm": 120, "drum_beat": "none", "synthesizer": "none", "rhythmic_steps": "none", "bass_style": "none", "guitar_style": "none", "target_volume": -23.0, "preset": "default", "max_steps": 1500, # keep for UI, chunking now fixed to 30s "bitrate": "192k", "output_sample_rate": "48000", "bit_depth": "16", "instrumental_prompt": "" } def load_settings_from_file() -> Dict[str, Any]: try: if os.path.exists(SETTINGS_FILE): with open(SETTINGS_FILE, "r") as f: data = json.load(f) # ensure all defaults present for k, v in DEFAULT_SETTINGS.items(): data.setdefault(k, v) logger.info(f"Loaded settings from {SETTINGS_FILE}") return data except Exception as e: logger.error(f"Failed reading {SETTINGS_FILE}: {e}") return DEFAULT_SETTINGS.copy() def save_settings_to_file(settings: Dict[str, Any]) -> None: try: with open(SETTINGS_FILE, "w") as f: json.dump(settings, f, indent=2) logger.info(f"Saved settings to {SETTINGS_FILE}") except Exception as e: logger.error(f"Failed saving {SETTINGS_FILE}: {e}") CURRENT_SETTINGS = load_settings_from_file() # ====================================================================================== # VRAM / DISK / MEMORY # ====================================================================================== def clean_memory() -> Optional[float]: try: torch.cuda.empty_cache() gc.collect() torch.cuda.ipc_collect() torch.cuda.synchronize() vram_mb = 
torch.cuda.memory_allocated() / 1024**2 logger.info(f"Memory cleaned. VRAM={vram_mb:.2f} MB") return vram_mb except Exception as e: logger.error(f"clean_memory failed: {e}") logger.error(traceback.format_exc()) return None def check_vram(): try: r = subprocess.run( ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'], capture_output=True, text=True ) lines = r.stdout.splitlines() if len(lines) > 1: used_mb, total_mb = map(int, re.findall(r'\d+', lines[1])) free_mb = total_mb - used_mb logger.info(f"VRAM: used {used_mb} MiB | free {free_mb} MiB | total {total_mb} MiB") if free_mb < 5000: logger.warning(f"Low free VRAM ({free_mb} MiB). Running processes:") procs = subprocess.run( ['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'], capture_output=True, text=True ) logger.info(f"\n{procs.stdout}") return free_mb except Exception as e: logger.error(f"check_vram failed: {e}") return None def check_disk_space(path=".") -> bool: try: stat = os.statvfs(path) free_gb = stat.f_bavail * stat.f_frsize / (1024**3) if free_gb < 1.0: logger.warning(f"Low disk space: {free_gb:.2f} GB") return free_gb >= 1.0 except Exception as e: logger.error(f"Disk space check failed: {e}") return False # ====================================================================================== # AUDIO UTILS (CPU) # ====================================================================================== def ensure_stereo(audio_segment: AudioSegment, sample_rate=48000, sample_width=2) -> AudioSegment: try: if audio_segment.channels != 2: audio_segment = audio_segment.set_channels(2) if audio_segment.frame_rate != sample_rate: audio_segment = audio_segment.set_frame_rate(sample_rate) return audio_segment except Exception as e: logger.error(f"ensure_stereo failed: {e}") return audio_segment def calculate_rms(segment: AudioSegment) -> float: try: samples = np.array(segment.get_array_of_samples(), dtype=np.float32) rms = float(np.sqrt(np.mean(samples**2))) return rms 
except Exception as e: logger.error(f"calculate_rms failed: {e}") return 0.0 def hard_limit(audio_segment: AudioSegment, limit_db=-3.0, sample_rate=48000) -> AudioSegment: try: audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) limit = 10 ** (limit_db / 20.0) * (2**23 if audio_segment.sample_width == 3 else 32767) samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) samples = np.clip(samples, -limit, limit).astype(np.int32 if audio_segment.sample_width == 3 else np.int16) if len(samples) % 2 != 0: samples = samples[:-1] return AudioSegment( samples.tobytes(), frame_rate=sample_rate, sample_width=audio_segment.sample_width, channels=2 ) except Exception as e: logger.error(f"hard_limit failed: {e}") return audio_segment def rms_normalize(segment: AudioSegment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000) -> AudioSegment: try: segment = ensure_stereo(segment, sample_rate, segment.sample_width) target_rms = 10 ** (target_rms_db / 20) * (2**23 if segment.sample_width == 3 else 32767) current_rms = calculate_rms(segment) if current_rms > 0: gain_factor = target_rms / current_rms segment = segment.apply_gain(20 * np.log10(max(gain_factor, 1e-6))) segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate) return segment except Exception as e: logger.error(f"rms_normalize failed: {e}") return segment def balance_stereo(audio_segment: AudioSegment, noise_threshold=-40, sample_rate=48000) -> AudioSegment: try: audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) if audio_segment.channels != 2: return audio_segment stereo = samples.reshape(-1, 2) db = 20 * np.log10(np.abs(stereo) + 1e-10) mask = db > noise_threshold stereo = stereo * mask left = stereo[:, 0] right = stereo[:, 1] l_rms = np.sqrt(np.mean(left[left != 0] ** 2)) if np.any(left != 0) else 0 r_rms = 
np.sqrt(np.mean(right[right != 0] ** 2)) if np.any(right != 0) else 0 if l_rms > 0 and r_rms > 0: avg = (l_rms + r_rms) / 2 stereo[:, 0] *= (avg / l_rms) stereo[:, 1] *= (avg / r_rms) out = stereo.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16) if len(out) % 2 != 0: out = out[:-1] return AudioSegment( out.tobytes(), frame_rate=sample_rate, sample_width=audio_segment.sample_width, channels=2 ) except Exception as e: logger.error(f"balance_stereo failed: {e}") return audio_segment def apply_noise_gate(audio_segment: AudioSegment, threshold_db=-80, sample_rate=48000) -> AudioSegment: try: audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) if audio_segment.channels != 2: return audio_segment stereo = samples.reshape(-1, 2) for _ in range(2): db = 20 * np.log10(np.abs(stereo) + 1e-10) mask = db > threshold_db stereo = stereo * mask out = stereo.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16) if len(out) % 2 != 0: out = out[:-1] return AudioSegment( out.tobytes(), frame_rate=sample_rate, sample_width=audio_segment.sample_width, channels=2 ) except Exception as e: logger.error(f"apply_noise_gate failed: {e}") return audio_segment def apply_eq(segment: AudioSegment, sample_rate=48000) -> AudioSegment: try: segment = ensure_stereo(segment, sample_rate, segment.sample_width) segment = segment.high_pass_filter(20) segment = segment.low_pass_filter(8000) segment = segment - 3 segment = segment - 3 segment = segment - 10 return segment except Exception as e: logger.error(f"apply_eq failed: {e}") return segment def apply_fade(segment: AudioSegment, fade_in_duration=500, fade_out_duration=500) -> AudioSegment: try: segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width) segment = segment.fade_in(fade_in_duration).fade_out(fade_out_duration) return segment except Exception as e: 
logger.error(f"apply_fade failed: {e}") return segment # ====================================================================================== # PROMPTS # ====================================================================================== def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num): try: bpm_range = (90, 130) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums" if drum_beat != "none" else ", standard rock drums with funk fills" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style} bass" if bass_style != "none" else ", funky slap bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", energetic guitar riffs" base = f"Instrumental alternative rock by Red Hot Chili Peppers{guitar}{bass}{drum}{synth}, funk-rock energy at {bpm} BPM" if chunk_num == 1: return base + ", dynamic intro and expressive verse." return base + ", powerful chorus and energetic outro." except Exception: return "" def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: bpm_range = (100, 130) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums, punk energy" if drum_beat != "none" else ", standard rock drums, punk energy" synth = f", {synthesizer}" if synthesizer != "none" else "" chosen_bass = random.choice(['deep bass', 'melodic bass']) if bass_style == "none" else bass_style bass = f", {chosen_bass}" chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style guitar = f", {chosen_guitar}" chosen_rhythm = random.choice(['steady steps', 'dynamic shifts']) if rhythmic_steps == "none" else rhythmic_steps rhythm = f", {chosen_rhythm}" return f"Instrumental grunge by Nirvana{guitar}{bass}{drum}{synth}, raw lo-fi production{rhythm} at {bpm} BPM." 
except Exception: return "" def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: bpm_range = (100, 140) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums, driving rhythm" if drum_beat != "none" else ", standard rock drums, driving rhythm" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style}, emotional tone" if bass_style != "none" else ", melodic bass, emotional tone" chosen_guitar = random.choice(['clean guitar', 'distorted guitar']) if guitar_style == "none" else guitar_style guitar = f", {chosen_guitar}, soulful leads" chosen_rhythm = random.choice(['steady steps', 'syncopated steps']) if rhythmic_steps == "none" else rhythmic_steps rhythm = f", {chosen_rhythm}" return f"Instrumental grunge by Pearl Jam{guitar}{bass}{drum}{synth}, classic rock influences{rhythm} at {bpm} BPM." except Exception: return "" def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: bpm_range = (90, 140) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums, heavy rhythm" if drum_beat != "none" else ", standard rock drums, heavy rhythm" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style}, sludgy tone" if bass_style != "none" else ", deep bass, sludgy tone" guitar = f", {guitar_style}, downtuned riffs, psychedelic vibe" if guitar_style != "none" else ", distorted guitar, downtuned riffs, psychedelic vibe" rhythm = f", {rhythmic_steps}" if rhythmic_steps != "none" else ", complex steps" return f"Instrumental grunge with heavy metal influences by Soundgarden{guitar}{bass}{drum}{synth}{rhythm} at {bpm} BPM." 
except Exception: return "" def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: bpm_range = (110, 150) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums, powerful drive" if drum_beat != "none" else ", standard rock drums, powerful drive" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style}, supportive tone" if bass_style != "none" else ", melodic bass, supportive tone" chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style guitar = f", {chosen_guitar}, anthemic quality" chosen_rhythm = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps rhythm = f", {chosen_rhythm}" return f"Instrumental alternative rock with post-grunge influences by Foo Fighters{guitar}, stadium-ready hooks{bass}{drum}{synth}{rhythm} at {bpm} BPM." except Exception: return "" def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: bpm_range = (120, 180) bpm = random.randint(*bpm_range) if bpm == 120 else bpm drum = f", {drum_beat} drums" if drum_beat != "none" else ", double bass drums" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style}" if bass_style != "none" else ", aggressive bass" guitar = f", {guitar_style}, blazing fast riffs" if guitar_style != "none" else ", distorted guitar, blazing fast riffs" rhythm = f", {rhythmic_steps}" if rhythmic_steps != "none" else ", complex steps" return f"Instrumental thrash metal by Metallica{guitar}{bass}{drum}{synth}{rhythm} at {bpm} BPM." 
except Exception: return "" def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else "" synth = f", {synthesizer}" if synthesizer != "none" else ", lush synths" bass = f", {bass_style} bass" if bass_style != "none" else "" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", dreamy guitar" return f"Instrumental alternative rock by Smashing Pumpkins{guitar}{synth}{drum}{bass} at {bpm} BPM." except Exception: return "" def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else "" synth = f", {synthesizer}" if synthesizer != "none" else ", atmospheric synths" bass = f", {bass_style} bass" if bass_style != "none" else ", hypnotic bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else "" return f"Instrumental experimental rock by Radiohead{synth}{bass}{drum}{guitar} at {bpm} BPM." except Exception: return "" def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else "" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style} bass" if bass_style != "none" else ", melodic bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", distorted guitar" return f"Instrumental alternative rock by Pixies{guitar}{bass}{drum}{synth} at {bpm} BPM." 
except Exception: return "" def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else ", precise drums" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style} bass" if bass_style != "none" else ", driving bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar" return f"Instrumental post-punk by Joy Division{guitar}{bass}{drum}{synth} at {bpm} BPM." except Exception: return "" def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else "" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style} bass" if bass_style != "none" else ", groovy bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar" return f"Instrumental indie rock by Arctic Monkeys{guitar}{bass}{drum}{synth} at {bpm} BPM." except Exception: return "" def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): try: drum = f", {drum_beat} drums" if drum_beat != "none" else ", heavy drums" synth = f", {synthesizer}" if synthesizer != "none" else "" bass = f", {bass_style} bass" if bass_style != "none" else ", slap bass" guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", funky guitar" return f"Instrumental funk rock by Rage Against the Machine{guitar}{bass}{drum}{synth} at {bpm} BPM." 
    except Exception:
        return ""


def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Detroit-techno prompt; each "none" style falls back to a genre default.

    ``bpm`` is inserted verbatim and ``rhythmic_steps`` is accepted but unused.
    Returns "" if building the string raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", four-on-the-floor drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ", pulsing synths"
        bass = f", {bass_style} bass" if bass_style != "none" else ", driving bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
        return f"Instrumental Detroit techno by Juan Atkins{synth}{bass}{drum}{guitar} at {bpm} BPM."
    except Exception:
        return ""


def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a deep-house prompt; each "none" style falls back to a genre default.

    ``bpm`` is inserted verbatim and ``rhythmic_steps`` is accepted but unused.
    Returns "" if building the string raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", steady kick drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ", warm synths"
        bass = f", {bass_style} bass" if bass_style != "none" else ", deep bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
        return f"Instrumental deep house by Larry Heard{synth}{bass}{drum}{guitar} at {bpm} BPM."
    except Exception:
        return ""


# Sampling presets: each entry supplies cfg_scale / top_k / top_p / temperature.
# generate_music substitutes these whenever the requested preset is not "default".
PRESETS = {
    "default": {"cfg_scale": 5.8, "top_k": 250, "top_p": 0.95, "temperature": 0.90},
    "rock": {"cfg_scale": 5.8, "top_k": 250, "top_p": 0.95, "temperature": 0.90},
    "techno": {"cfg_scale": 5.2, "top_k": 300, "top_p": 0.96, "temperature": 0.95},
    "grunge": {"cfg_scale": 6.2, "top_k": 220, "top_p": 0.94, "temperature": 0.90},
    "indie": {"cfg_scale": 5.5, "top_k": 240, "top_p": 0.95, "temperature": 0.92},
    "funk_rock": {"cfg_scale": 5.8, "top_k": 260, "top_p": 0.96, "temperature": 0.94},
}

# ======================================================================================
# MODEL LOAD
# ======================================================================================

# audiocraft is a hard dependency: log a hint and re-raise the import error.
try:
    from audiocraft.models import MusicGen
except Exception as e:
    logger.error("audiocraft is required. pip install audiocraft")
    raise


def load_model():
    """Load MusicGen from ``./models/musicgen-large`` onto ``DEVICE``.

    Free VRAM is logged first and a warning is emitted below 5000 MiB. The
    process exits with status 1 if the model directory is missing.

    Returns:
        The MusicGen instance, with base generation params (30 s, single-step CFG).
    """
    free_vram = check_vram()
    if free_vram is not None and free_vram < 5000:
        logger.warning("Low free VRAM; consider closing other apps.")
    clean_memory()
    local_model_path = "./models/musicgen-large"
    if not os.path.exists(local_model_path):
        logger.error(f"Model path missing: {local_model_path}")
        sys.exit(1)
    logger.info("Loading MusicGen (large)...")
    # NOTE(review): autocast only affects ops executed inside it; it likely does
    # not convert the loaded weights to fp16 — confirm the intended precision.
    with autocast(dtype=torch.float16):
        model = MusicGen.get_pretrained(local_model_path, device=DEVICE)
    # base params get overridden per-call
    model.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen loaded.")
    return model


# Loaded once at import time; shared by every request.
musicgen_model = load_model()

# ======================================================================================
# GENERATION PIPELINE (30s CHUNKING, SEAMLESS MERGE)
# ======================================================================================


def get_latest_log() -> str:
    """Return the contents of the newest ``musicgen_log_*.log`` in LOG_DIR (by mtime)."""
    try:
        files = sorted(Path(LOG_DIR).glob("musicgen_log_*.log"), key=os.path.getmtime, reverse=True)
        if not files:
            return "No log files found."
return files[0].read_text() except Exception as e: return f"Error reading log: {e}" def set_bitrate_128(): return "128k" def set_bitrate_192(): return "192k" def set_bitrate_320(): return "320k" def set_sample_rate_22050(): return "22050" def set_sample_rate_44100(): return "44100" def set_sample_rate_48000(): return "48000" def set_bit_depth_16(): return "16" def set_bit_depth_24(): return "24" def generate_music_wrapper(*args): try: return generate_music(*args) finally: clean_memory() def _export_torch_to_segment(audio_tensor: torch.Tensor, sample_rate: int, bit_depth_int: int) -> Optional[AudioSegment]: """Helper: save torch stereo float32 to WAV -> load with pydub as segment.""" temp = f"temp_audio_{int(time.time()*1000)}.wav" try: torchaudio.save(temp, audio_tensor, sample_rate, bits_per_sample=bit_depth_int) with open(temp, "rb") as f: mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) seg = AudioSegment.from_wav(temp) mm.close() return seg except Exception as e: logger.error(f"_export_torch_to_segment failed: {e}") logger.error(traceback.format_exc()) return None finally: try: if os.path.exists(temp): os.remove(temp) except OSError: pass def _crossfade_segments(seg_a: AudioSegment, seg_b: AudioSegment, overlap_ms: int, sample_rate: int, bit_depth_int: int) -> AudioSegment: """Blend tail of seg_a with head of seg_b using hann window for seamless merge.""" try: seg_a = ensure_stereo(seg_a, sample_rate, seg_a.sample_width) seg_b = ensure_stereo(seg_b, sample_rate, seg_b.sample_width) if overlap_ms <= 0 or len(seg_a) < overlap_ms or len(seg_b) < overlap_ms: return seg_a + seg_b # export overlaps prev_wav = f"tmp_prev_{int(time.time()*1000)}.wav" curr_wav = f"tmp_curr_{int(time.time()*1000)}.wav" try: seg_a[-overlap_ms:].export(prev_wav, format="wav") seg_b[:overlap_ms].export(curr_wav, format="wav") a_audio, sr_a = torchaudio.load(prev_wav) b_audio, sr_b = torchaudio.load(curr_wav) if sr_a != sample_rate: a_audio = torchaudio.functional.resample(a_audio, 
sr_a, sample_rate, lowpass_filter_width=64) if sr_b != sample_rate: b_audio = torchaudio.functional.resample(b_audio, sr_b, sample_rate, lowpass_filter_width=64) n = min(a_audio.shape[1], b_audio.shape[1]) n = n - (n % 2) if n <= 0: return seg_a + seg_b a = a_audio[:, :n] b = b_audio[:, :n] hann = torch.hann_window(n, periodic=False) fade_in = hann fade_out = hann.flip(0) blended = (a * fade_out + b * fade_in).to(torch.float32) blended = torch.clamp(blended, -1.0, 1.0) # scale to PCM and save scale = (2**23 if bit_depth_int == 24 else 32767) blended_i = (blended * scale).to(torch.int32 if bit_depth_int == 24 else torch.int16) temp_x = f"tmp_cross_{int(time.time()*1000)}.wav" torchaudio.save(temp_x, blended_i, sample_rate, bits_per_sample=bit_depth_int) blended_seg = AudioSegment.from_wav(temp_x) blended_seg = ensure_stereo(blended_seg, sample_rate, blended_seg.sample_width) # combine result = seg_a[:-overlap_ms] + blended_seg + seg_b[overlap_ms:] try: if os.path.exists(temp_x): os.remove(temp_x) except OSError: pass return result finally: for p in [prev_wav, curr_wav]: try: if os.path.exists(p): os.remove(p) except OSError: pass except Exception as e: logger.error(f"_crossfade_segments failed: {e}") return seg_a + seg_b def generate_music( instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, bpm: int, drum_beat: str, synthesizer: str, rhythmic_steps: str, bass_style: str, guitar_style: str, target_volume: float, preset: str, max_steps: str, # kept for UI parity vram_status_text: str, bitrate: str, output_sample_rate: str, bit_depth: str ) -> Tuple[Optional[str], str, str]: global musicgen_model if not instrumental_prompt or not instrumental_prompt.strip(): return None, "⚠️ Please enter a valid instrumental prompt!", vram_status_text try: # Apply preset if not default if preset != "default": p = PRESETS.get(preset, PRESETS["default"]) cfg_scale, top_k, top_p, temperature = p["cfg_scale"], p["top_k"], 
p["top_p"], p["temperature"] logger.info(f"Preset '{preset}' applied: cfg={cfg_scale} top_k={top_k} top_p={top_p} temp={temperature}") # Validate numerics try: output_sr_int = int(output_sample_rate) except: return None, "❌ Invalid output sampling rate; choose 22050/44100/48000", vram_status_text try: bit_depth_int = int(bit_depth) sample_width = 3 if bit_depth_int == 24 else 2 except: return None, "❌ Invalid bit depth; choose 16 or 24", vram_status_text if not check_disk_space(): return None, "⚠️ Low disk space (<1GB).", vram_status_text # Chunking: EXACT 30s per chunk (unify stepping -> always 30s). Two chunks => full 60s song. CHUNK_SEC = 30 total_duration = max(30, min(int(total_duration), 120)) num_chunks = math.ceil(total_duration / CHUNK_SEC) # Internal processing rate (resample to this for DSP) PROCESS_SR = 48000 OVERLAP_SEC = 0.20 # 200ms crossfade/prompt tail channels = 2 # Seed & params seed = random.randint(0, 2**31 - 1) random.seed(seed) torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed_all(seed) musicgen_model.set_generation_params( duration=CHUNK_SEC, use_sampling=True, top_k=int(top_k), top_p=float(top_p), temperature=float(temperature), cfg_coef=float(cfg_scale), two_step_cfg=False, ) vram_status_text = f"Start VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" segments: List[AudioSegment] = [] start_time = time.time() for idx in range(num_chunks): chunk_idx = idx + 1 dur = CHUNK_SEC if (idx < num_chunks - 1) else (total_duration - CHUNK_SEC * (num_chunks - 1) or CHUNK_SEC) logger.info(f"Generating chunk {chunk_idx}/{num_chunks} ({dur}s)") # Prompt per chunk (variable for RHCP only) if "Red Hot Chili Peppers" in instrumental_prompt: prompt_text = set_red_hot_chili_peppers_prompt( bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_idx ) else: prompt_text = instrumental_prompt try: with torch.no_grad(): with autocast(dtype=torch.float16): clean_memory() if idx == 0: audio = 
musicgen_model.generate([prompt_text], progress=True)[0].cpu() else: # Use tail of previous segment as continuation prompt prev_seg = segments[-1] prev_seg = apply_noise_gate(prev_seg, threshold_db=-80, sample_rate=PROCESS_SR) prev_seg = balance_stereo(prev_seg, noise_threshold=-40, sample_rate=PROCESS_SR) temp_prev = f"prev_{int(time.time()*1000)}.wav" try: prev_seg.export(temp_prev, format="wav") prev_audio, prev_sr = torchaudio.load(temp_prev) if prev_sr != PROCESS_SR: prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, PROCESS_SR, lowpass_filter_width=64) if prev_audio.shape[0] != 2: prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]] prev_audio = prev_audio.to(DEVICE) tail = prev_audio[:, -int(PROCESS_SR * OVERLAP_SEC):] audio = musicgen_model.generate_continuation( prompt=tail, prompt_sample_rate=PROCESS_SR, descriptions=[prompt_text], progress=True )[0].cpu() del prev_audio, tail finally: try: if os.path.exists(temp_prev): os.remove(temp_prev) except OSError: pass clean_memory() except Exception as e: logger.error(f"Chunk {chunk_idx} generation failed: {e}") logger.error(traceback.format_exc()) return None, f"❌ Failed to generate chunk {chunk_idx}: {e}", vram_status_text try: # Ensure stereo & resample to PROCESS_SR for DSP if audio.shape[0] != 2: audio = audio.repeat(2, 1)[:, :audio.shape[1]] audio = audio.to(dtype=torch.float32) audio = torchaudio.functional.resample(audio, 32000, PROCESS_SR, lowpass_filter_width=64) seg = _export_torch_to_segment(audio, PROCESS_SR, bit_depth_int) if seg is None: return None, f"❌ Failed to convert audio for chunk {chunk_idx}", vram_status_text seg = ensure_stereo(seg, PROCESS_SR, sample_width) seg = seg - 15 seg = apply_noise_gate(seg, threshold_db=-80, sample_rate=PROCESS_SR) seg = balance_stereo(seg, noise_threshold=-40, sample_rate=PROCESS_SR) seg = rms_normalize(seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR) seg = apply_eq(seg, sample_rate=PROCESS_SR) # Trim exactly 
to 'dur' seconds for last chunk seg = seg[:dur * 1000] segments.append(seg) del audio clean_memory() vram_status_text = f"VRAM after chunk {chunk_idx}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" except Exception as e: logger.error(f"Post-processing failed (chunk {chunk_idx}): {e}") logger.error(traceback.format_exc()) return None, f"❌ Failed to process chunk {chunk_idx}: {e}", vram_status_text if not segments: return None, "❌ No audio generated.", vram_status_text # Seamless join with crossfades logger.info("Combining chunks...") final_seg = segments[0] overlap_ms = int(OVERLAP_SEC * 1000) for i in range(1, len(segments)): final_seg = _crossfade_segments(final_seg, segments[i], overlap_ms, PROCESS_SR, bit_depth_int) # Final length clamp final_seg = final_seg[:total_duration * 1000] # Final polish final_seg = apply_noise_gate(final_seg, threshold_db=-80, sample_rate=PROCESS_SR) final_seg = balance_stereo(final_seg, noise_threshold=-40, sample_rate=PROCESS_SR) final_seg = rms_normalize(final_seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR) final_seg = apply_eq(final_seg, sample_rate=PROCESS_SR) final_seg = apply_fade(final_seg, 500, 800) final_seg = final_seg - 10 final_seg = final_seg.set_frame_rate(output_sr_int) # Export MP3 mp3_path = f"ghostai_music_{int(time.time())}.mp3" try: clean_memory() final_seg.export(mp3_path, format="mp3", bitrate=bitrate, tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}) except Exception as e: logger.error(f"MP3 export failed ({bitrate}): {e}") fb = f"ghostai_music_fallback_{int(time.time())}.mp3" try: final_seg.export(fb, format="mp3", bitrate="128k") mp3_path = fb except Exception as ee: return None, f"❌ Failed to export MP3: {ee}", vram_status_text elapsed = time.time() - start_time vram_status_text = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" logger.info(f"Done in {elapsed:.2f}s -> {mp3_path}") return mp3_path, "✅ Done! 30s chunking unified seamlessly. 
Check output loudness/quality.", vram_status_text except Exception as e: logger.error(f"Generation failed: {e}") logger.error(traceback.format_exc()) return None, f"❌ Generation failed: {e}", vram_status_text finally: clean_memory() def clear_inputs(): s = DEFAULT_SETTINGS.copy() return ( s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"], s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"], s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"], s["bitrate"], s["output_sample_rate"], s["bit_depth"] ) # ====================================================================================== # SERVER STATUS (BUSY/IDLE) & RENDER API # ====================================================================================== BUSY_LOCK = threading.Lock() BUSY_FLAG = False BUSY_FILE = "/tmp/musicgen_busy.lock" CURRENT_JOB: Dict[str, Any] = {"id": None, "start": None} def set_busy(val: bool, job_id: Optional[str] = None): global BUSY_FLAG, CURRENT_JOB with BUSY_LOCK: BUSY_FLAG = val if val: CURRENT_JOB["id"] = job_id or f"job_{int(time.time())}" CURRENT_JOB["start"] = time.time() try: Path(BUSY_FILE).write_text(CURRENT_JOB["id"]) except Exception: pass else: CURRENT_JOB["id"] = None CURRENT_JOB["start"] = None try: if os.path.exists(BUSY_FILE): os.remove(BUSY_FILE) except Exception: pass def is_busy() -> bool: with BUSY_LOCK: return BUSY_FLAG def job_elapsed() -> float: with BUSY_LOCK: if CURRENT_JOB["start"] is None: return 0.0 return time.time() - CURRENT_JOB["start"] class RenderRequest(BaseModel): instrumental_prompt: str cfg_scale: Optional[float] = None top_k: Optional[int] = None top_p: Optional[float] = None temperature: Optional[float] = None total_duration: Optional[int] = None bpm: Optional[int] = None drum_beat: Optional[str] = None synthesizer: Optional[str] = None rhythmic_steps: Optional[str] = None bass_style: Optional[str] = None guitar_style: Optional[str] = None 
target_volume: Optional[float] = None preset: Optional[str] = None max_steps: Optional[int] = None bitrate: Optional[str] = None output_sample_rate: Optional[str] = None bit_depth: Optional[str] = None class SettingsUpdate(BaseModel): settings: Dict[str, Any] fastapp = FastAPI(title="GhostAI Music Server", version="1.0") fastapp.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @fastapp.get("/health") def health(): return {"ok": True, "ts": int(time.time())} @fastapp.get("/status") def status(): busy = is_busy() return { "busy": busy, "job_id": CURRENT_JOB["id"], "since": CURRENT_JOB["start"], "elapsed": job_elapsed(), "lockfile": os.path.exists(BUSY_FILE) } @fastapp.get("/config") def get_config(): return {"defaults": CURRENT_SETTINGS} @fastapp.post("/settings") def set_settings(payload: SettingsUpdate): try: s = CURRENT_SETTINGS.copy() s.update(payload.settings or {}) save_settings_to_file(s) for k, v in s.items(): CURRENT_SETTINGS[k] = v return {"ok": True, "saved": s} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @fastapp.post("/render") def render(req: RenderRequest): if is_busy(): raise HTTPException(status_code=409, detail="Server busy") job_id = f"render_{int(time.time())}" set_busy(True, job_id) try: s = CURRENT_SETTINGS.copy() # apply overrides for k, v in req.dict().items(): if v is not None: s[k] = v mp3, msg, vram = generate_music( s.get("instrumental_prompt", req.instrumental_prompt), float(s.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), int(s.get("top_k", DEFAULT_SETTINGS["top_k"])), float(s.get("top_p", DEFAULT_SETTINGS["top_p"])), float(s.get("temperature", DEFAULT_SETTINGS["temperature"])), int(s.get("total_duration", DEFAULT_SETTINGS["total_duration"])), int(s.get("bpm", DEFAULT_SETTINGS["bpm"])), str(s.get("drum_beat", DEFAULT_SETTINGS["drum_beat"])), str(s.get("synthesizer", DEFAULT_SETTINGS["synthesizer"])), str(s.get("rhythmic_steps", 
DEFAULT_SETTINGS["rhythmic_steps"])), str(s.get("bass_style", DEFAULT_SETTINGS["bass_style"])), str(s.get("guitar_style", DEFAULT_SETTINGS["guitar_style"])), float(s.get("target_volume", DEFAULT_SETTINGS["target_volume"])), str(s.get("preset", DEFAULT_SETTINGS["preset"])), str(s.get("max_steps", DEFAULT_SETTINGS["max_steps"])), "", str(s.get("bitrate", DEFAULT_SETTINGS["bitrate"])), str(s.get("output_sample_rate", DEFAULT_SETTINGS["output_sample_rate"])), str(s.get("bit_depth", DEFAULT_SETTINGS["bit_depth"])) ) if not mp3: raise HTTPException(status_code=500, detail=msg) return {"ok": True, "job_id": job_id, "path": mp3, "status": msg, "vram": vram} finally: set_busy(False, None) def _start_fastapi(): uvicorn.run(fastapp, host="0.0.0.0", port=8555, log_level="info") api_thread = threading.Thread(target=_start_fastapi, daemon=True) api_thread.start() logger.info("FastAPI server started on http://0.0.0.0:8555") # ====================================================================================== # GRADIO UI (HIGH CONTRAST / WHITE TEXT) # ====================================================================================== CSS = """ :root { color-scheme: dark; } body, .gradio-container, .block, .tabs, .panel, .form, .wrap { background: #0B0B0D !important; color: #FFFFFF !important; } * { color: #FFFFFF !important; } label, p, span, h1, h2, h3, h4, h5, h6 { color: #FFFFFF !important; } input, textarea, select { background: #15151A !important; color: #FFFFFF !important; border: 1px solid #2B2B33 !important; } button { background: #1F6FEB !important; color: #FFFFFF !important; border: 2px solid transparent !important; border-radius: 8px !important; padding: 10px 16px !important; font-weight: 700 !important; } button:hover { background: #2D7BFF !important; } button:focus { outline: 3px solid #00C853 !important; } .slider > input { accent-color: #FFD600 !important; } .group-container { border: 1px solid #2B2B33; border-radius: 10px; padding: 16px; } .header { 
text-align:center; padding: 12px 16px; border-bottom: 2px solid #00C853; } .header h1 { font-size: 28px; margin: 6px 0 0 0; } .header .logo { font-size: 44px; } """ loaded = CURRENT_SETTINGS logger.info("Building Gradio UI...") with gr.Blocks(css=CSS, analytics_enabled=False, title="GhostAI Music Generator") as demo: gr.Markdown(f""" """) with gr.Column(elem_classes="input-container"): gr.Markdown("### Prompt") instrumental_prompt = gr.Textbox( label="Instrumental Prompt", placeholder="Type your instrumental prompt or use genre buttons below", lines=4, value=loaded.get("instrumental_prompt", ""), ) with gr.Row(): rhcp_btn = gr.Button("Red Hot Chili Peppers 🌶️") nirvana_btn = gr.Button("Nirvana 🎸") pearl_jam_btn = gr.Button("Pearl Jam 🦪") soundgarden_btn = gr.Button("Soundgarden 🌑") foo_fighters_btn = gr.Button("Foo Fighters 🤘") with gr.Row(): smashing_pumpkins_btn = gr.Button("Smashing Pumpkins 🎃") radiohead_btn = gr.Button("Radiohead 🧠") classic_rock_btn = gr.Button("Metallica Heavy Metal 🎸") alternative_rock_btn = gr.Button("Alternative Rock 🎵") post_punk_btn = gr.Button("Post-Punk 🖤") with gr.Row(): indie_rock_btn = gr.Button("Indie Rock 🎤") funk_rock_btn = gr.Button("Funk Rock 🕺") detroit_techno_btn = gr.Button("Detroit Techno 🎛️") deep_house_btn = gr.Button("Deep House 🏠") with gr.Column(elem_classes="settings-container"): gr.Markdown("### Settings") with gr.Group(elem_classes="group-container"): cfg_scale = gr.Slider(1.0, 10.0, step=0.1, value=float(loaded.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), label="CFG Scale") top_k = gr.Slider(10, 500, step=10, value=int(loaded.get("top_k", DEFAULT_SETTINGS["top_k"])), label="Top-K") top_p = gr.Slider(0.0, 1.0, step=0.01, value=float(loaded.get("top_p", DEFAULT_SETTINGS["top_p"])), label="Top-P") temperature = gr.Slider(0.1, 2.0, step=0.01, value=float(loaded.get("temperature", DEFAULT_SETTINGS["temperature"])), label="Temperature") total_duration = gr.Dropdown(choices=[30, 60, 90, 120], 
value=int(loaded.get("total_duration", 60)), label="Song Length (seconds)") bpm = gr.Slider(60, 180, step=1, value=int(loaded.get("bpm", 120)), label="Tempo (BPM)") drum_beat = gr.Dropdown(choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"], value=str(loaded.get("drum_beat", "none")), label="Drum Beat") synthesizer = gr.Dropdown(choices=["none", "analog synth", "digital pad", "arpeggiated synth"], value=str(loaded.get("synthesizer", "none")), label="Synthesizer") rhythmic_steps = gr.Dropdown(choices=["none", "syncopated steps", "steady steps", "complex steps"], value=str(loaded.get("rhythmic_steps", "none")), label="Rhythmic Steps") bass_style = gr.Dropdown(choices=["none", "slap bass", "deep bass", "melodic bass"], value=str(loaded.get("bass_style", "none")), label="Bass Style") guitar_style = gr.Dropdown(choices=["none", "distorted", "clean", "jangle"], value=str(loaded.get("guitar_style", "none")), label="Guitar Style") target_volume = gr.Slider(-30.0, -20.0, step=0.5, value=float(loaded.get("target_volume", -23.0)), label="Target Loudness (dBFS RMS)") preset = gr.Dropdown(choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], value=str(loaded.get("preset", "default")), label="Preset") max_steps = gr.Dropdown(choices=[1000, 1200, 1300, 1500], value=int(loaded.get("max_steps", 1500)), label="Max Steps (per chunk hint)") bitrate_state = gr.State(value=str(loaded.get("bitrate", "192k"))) sample_rate_state = gr.State(value=str(loaded.get("output_sample_rate", "48000"))) bit_depth_state = gr.State(value=str(loaded.get("bit_depth", "16"))) with gr.Row(): bitrate_128_btn = gr.Button("Bitrate 128k") bitrate_192_btn = gr.Button("Bitrate 192k") bitrate_320_btn = gr.Button("Bitrate 320k") with gr.Row(): sample_rate_22050_btn = gr.Button("SR 22.05k") sample_rate_44100_btn = gr.Button("SR 44.1k") sample_rate_48000_btn = gr.Button("SR 48k") with gr.Row(): bit_depth_16_btn = gr.Button("16-bit") bit_depth_24_btn = gr.Button("24-bit") 
with gr.Row(): gen_btn = gr.Button("Generate Music 🚀") clr_btn = gr.Button("Clear 🧹") save_btn = gr.Button("Save Settings 💾") load_btn = gr.Button("Load Settings 📂") reset_btn = gr.Button("Reset Defaults ♻️") with gr.Column(elem_classes="output-container"): gr.Markdown("### Output") out_audio = gr.Audio(label="Generated Track", type="filepath") status_box = gr.Textbox(label="Status", interactive=False) vram_box = gr.Textbox(label="VRAM Usage", interactive=False, value="") with gr.Column(elem_classes="logs-container"): gr.Markdown("### Logs") log_output = gr.Textbox(label="Last Log File", lines=16, interactive=False) log_btn = gr.Button("View Last Log") # Genre buttons -> prompt text (chunk_num fixed = 1 for initial suggestion) rhcp_btn.click( set_red_hot_chili_peppers_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, gr.State(value=1)], outputs=instrumental_prompt ) nirvana_btn.click(set_nirvana_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) pearl_jam_btn.click(set_pearl_jam_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) soundgarden_btn.click(set_soundgarden_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) foo_fighters_btn.click(set_foo_fighters_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) radiohead_btn.click(set_radiohead_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) classic_rock_btn.click(set_classic_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) alternative_rock_btn.click(set_alternative_rock_prompt, [bpm, drum_beat, 
synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) post_punk_btn.click(set_post_punk_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) indie_rock_btn.click(set_indie_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) funk_rock_btn.click(set_funk_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) detroit_techno_btn.click(set_detroit_techno_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) deep_house_btn.click(set_deep_house_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt) # Bitrate / SR / Bit depth quick-sets bitrate_128_btn.click(set_bitrate_128, outputs=bitrate_state) bitrate_192_btn.click(set_bitrate_192, outputs=bitrate_state) bitrate_320_btn.click(set_bitrate_320, outputs=bitrate_state) sample_rate_22050_btn.click(set_sample_rate_22050, outputs=sample_rate_state) sample_rate_44100_btn.click(set_sample_rate_44100, outputs=sample_rate_state) sample_rate_48000_btn.click(set_sample_rate_48000, outputs=sample_rate_state) bit_depth_16_btn.click(set_bit_depth_16, outputs=bit_depth_state) bit_depth_24_btn.click(set_bit_depth_24, outputs=bit_depth_state) # Generate gen_btn.click( generate_music_wrapper, inputs=[ instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, vram_box, bitrate_state, sample_rate_state, bit_depth_state ], outputs=[out_audio, status_box, vram_box] ) # Clear clr_btn.click( clear_inputs, outputs=[ instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state ] ) # Save / 
Load / Reset actions def _save_action( instrumental_prompt_v, cfg_v, top_k_v, top_p_v, temp_v, dur_v, bpm_v, drum_v, synth_v, steps_v, bass_v, guitar_v, vol_v, preset_v, maxsteps_v, br_v, sr_v, bd_v ): s = { "instrumental_prompt": instrumental_prompt_v, "cfg_scale": float(cfg_v), "top_k": int(top_k_v), "top_p": float(top_p_v), "temperature": float(temp_v), "total_duration": int(dur_v), "bpm": int(bpm_v), "drum_beat": str(drum_v), "synthesizer": str(synth_v), "rhythmic_steps": str(steps_v), "bass_style": str(bass_v), "guitar_style": str(guitar_v), "target_volume": float(vol_v), "preset": str(preset_v), "max_steps": int(maxsteps_v), "bitrate": str(br_v), "output_sample_rate": str(sr_v), "bit_depth": str(bd_v) } save_settings_to_file(s) for k, v in s.items(): CURRENT_SETTINGS[k] = v return "✅ Settings saved." def _load_action(): s = load_settings_from_file() for k, v in s.items(): CURRENT_SETTINGS[k] = v return ( s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"], s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"], s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"], s["bitrate"], s["output_sample_rate"], s["bit_depth"], "✅ Settings loaded." ) def _reset_action(): s = DEFAULT_SETTINGS.copy() save_settings_to_file(s) for k, v in s.items(): CURRENT_SETTINGS[k] = v return ( s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"], s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"], s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"], s["bitrate"], s["output_sample_rate"], s["bit_depth"], "✅ Defaults restored." 
) save_btn.click( _save_action, inputs=[ instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state ], outputs=status_box ) load_btn.click( _load_action, outputs=[ instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box ] ) reset_btn.click( _reset_action, outputs=[ instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box ] ) # Logs log_btn.click(get_latest_log, outputs=log_output) # ====================================================================================== # LAUNCH GRADIO # ====================================================================================== logger.info("Launching Gradio UI at http://0.0.0.0:9999 ...") try: demo.launch( server_name="0.0.0.0", server_port=9999, share=False, inbrowser=False, show_error=True ) except Exception as e: logger.error(f"Failed to launch Gradio UI: {e}") logger.error(traceback.format_exc()) sys.exit(1)