# GHOSTSONAFB / stable12gblg30sec.py
# ghostai1's picture
# Create stable12gblg30sec.py
# aaae0c3 verified
# raw
# history blame
# 57.7 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import gc
import re
import json
import time
import mmap
import math
import torch
import random
import logging
import warnings
import traceback
import subprocess
import numpy as np
import torchaudio
import gradio as gr
import gradio_client.utils
from pydub import AudioSegment
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, List
from torch.cuda.amp import autocast
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import threading
# ======================================================================================
# PATCHES & RUNTIME SETUP
# ======================================================================================
# Gradio schema bool patch (prevents crash for boolean schemas)
# Keep a handle on gradio_client's own resolver so that every non-boolean
# schema is still handled by the upstream implementation.
_original_get_type = gradio_client.utils.get_type


def _patched_get_type(schema):
    """Return the type name for ``schema``, tolerating bare boolean schemas.

    A schema that is literally ``True`` or ``False`` is reported as
    ``"boolean"``; anything else is delegated to the original resolver.
    """
    return "boolean" if isinstance(schema, bool) else _original_get_type(schema)


# Install the wrapper in place of the library function.
gradio_client.utils.get_type = _patched_get_type
# Warnings: suppress every warning category for the whole process.
warnings.filterwarnings("ignore")
# Allocator for CUDA 12.x
# NOTE(review): this is set after `import torch`; presumably it is still
# honoured because no CUDA allocation has happened yet — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Determinism/Benchmark settings: cuDNN benchmark mode off, deterministic mode on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Logging: one timestamped log file per process start under LOG_DIR,
# mirrored to stdout, at DEBUG level.
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("ghostai-musicgen")
# Device: the script refuses to run without CUDA and exits with status 1
# while the module is still being imported.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    logger.error("CUDA is required. Exiting.")
    sys.exit(1)
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
logger.info("Precision: fp16 model, fp32 CPU audio ops")
# ======================================================================================
# SETTINGS PERSISTENCE
# ======================================================================================
# JSON file (relative to the working directory) that persists the settings.
SETTINGS_FILE = "settings.json"

# Baseline value for every persisted setting. Key order is kept stable because
# the dict is written to disk and iterated when back-filling loaded settings.
DEFAULT_SETTINGS: Dict[str, Any] = {
    # --- sampling ---
    "cfg_scale": 5.8,
    "top_k": 250,            # more creative search space
    "top_p": 0.95,           # user requested higher probability cap
    "temperature": 0.90,     # user requested ~0.9
    # --- arrangement ---
    "total_duration": 60,    # default to 1 minute
    "bpm": 120,
    "drum_beat": "none",
    "synthesizer": "none",
    "rhythmic_steps": "none",
    "bass_style": "none",
    "guitar_style": "none",
    # --- loudness / preset ---
    "target_volume": -23.0,
    "preset": "default",
    "max_steps": 1500,       # keep for UI, chunking now fixed to 30s
    # --- export (kept as strings) ---
    "bitrate": "192k",
    "output_sample_rate": "48000",
    "bit_depth": "16",
    # --- prompt ---
    "instrumental_prompt": "",
}
def load_settings_from_file() -> Dict[str, Any]:
    """Read the persisted settings, back-filled with any missing defaults.

    Returns:
        The parsed contents of ``SETTINGS_FILE`` with every key from
        ``DEFAULT_SETTINGS`` guaranteed present, or a copy of
        ``DEFAULT_SETTINGS`` when the file is absent or cannot be read/parsed
        (the failure is logged).
    """
    try:
        if os.path.exists(SETTINGS_FILE):
            with open(SETTINGS_FILE, "r") as handle:
                loaded = json.load(handle)
            # Older files may predate newer keys; fill the gaps in place.
            for key, fallback in DEFAULT_SETTINGS.items():
                loaded.setdefault(key, fallback)
            logger.info(f"Loaded settings from {SETTINGS_FILE}")
            return loaded
    except Exception as e:
        logger.error(f"Failed reading {SETTINGS_FILE}: {e}")
    return DEFAULT_SETTINGS.copy()
def save_settings_to_file(settings: Dict[str, Any]) -> None:
    """Write ``settings`` to ``SETTINGS_FILE`` as indented JSON.

    Best effort: any failure is logged and swallowed, nothing is raised.
    """
    try:
        with open(SETTINGS_FILE, "w") as out_file:
            json.dump(settings, out_file, indent=2)
        logger.info(f"Saved settings to {SETTINGS_FILE}")
    except Exception as e:
        logger.error(f"Failed saving {SETTINGS_FILE}: {e}")
# Settings snapshot taken once at import time (file contents or defaults).
CURRENT_SETTINGS = load_settings_from_file()
# ======================================================================================
# VRAM / DISK / MEMORY
# ======================================================================================
def clean_memory() -> Optional[float]:
    """Release cached CUDA memory, run the garbage collector, and report usage.

    Returns:
        Currently allocated CUDA memory in MB, or ``None`` if any step failed
        (the failure and its traceback are logged).
    """
    try:
        # Order matters: cache flush, Python GC, IPC handles, then a sync.
        for step in (
            torch.cuda.empty_cache,
            gc.collect,
            torch.cuda.ipc_collect,
            torch.cuda.synchronize,
        ):
            step()
        allocated_mb = torch.cuda.memory_allocated() / 1024**2
        logger.info(f"Memory cleaned. VRAM={allocated_mb:.2f} MB")
        return allocated_mb
    except Exception as e:
        logger.error(f"clean_memory failed: {e}")
        logger.error(traceback.format_exc())
        return None
def check_vram():
    """Query ``nvidia-smi`` for memory usage and return free VRAM in MiB.

    Only the first data row of the CSV output is parsed. When less than
    5000 MiB is free, the list of compute processes is logged as well.

    Returns:
        Free memory in MiB, or ``None`` when the output has no data row or the
        query fails (failures are logged).
    """
    try:
        query = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'],
            capture_output=True, text=True
        )
        rows = query.stdout.splitlines()
        if len(rows) <= 1:
            # Header only (or nothing at all): no figures to report.
            return None
        used_mb, total_mb = (int(number) for number in re.findall(r'\d+', rows[1]))
        free_mb = total_mb - used_mb
        logger.info(f"VRAM: used {used_mb} MiB | free {free_mb} MiB | total {total_mb} MiB")
        if free_mb < 5000:
            logger.warning(f"Low free VRAM ({free_mb} MiB). Running processes:")
            apps = subprocess.run(
                ['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'],
                capture_output=True, text=True
            )
            logger.info(f"\n{apps.stdout}")
        return free_mb
    except Exception as e:
        logger.error(f"check_vram failed: {e}")
    return None
def check_disk_space(path=".") -> bool:
    """Report whether at least 1 GB is free on the filesystem holding ``path``.

    Uses ``shutil.disk_usage`` instead of the POSIX-only ``os.statvfs`` so the
    check also works on platforms without ``statvfs`` (previously the
    resulting error was swallowed and the function always answered ``False``).

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        ``True`` when 1 GB or more is available, ``False`` when less is
        available or the query fails (both cases are logged).
    """
    import shutil  # local import: the module is not imported at file level

    try:
        free_gb = shutil.disk_usage(path).free / (1024**3)
        if free_gb < 1.0:
            logger.warning(f"Low disk space: {free_gb:.2f} GB")
        return free_gb >= 1.0
    except Exception as e:
        logger.error(f"Disk space check failed: {e}")
        return False
# ======================================================================================
# AUDIO UTILS (CPU)
# ======================================================================================
def ensure_stereo(audio_segment: AudioSegment, sample_rate=48000, sample_width=2) -> AudioSegment:
    """Return the segment as 2-channel audio at ``sample_rate``.

    ``sample_width`` is accepted but not used. If a conversion step fails the
    error is logged and the segment is returned in whatever state it had
    reached.
    """
    converted = audio_segment
    try:
        if converted.channels != 2:
            converted = converted.set_channels(2)
        if converted.frame_rate != sample_rate:
            converted = converted.set_frame_rate(sample_rate)
    except Exception as e:
        logger.error(f"ensure_stereo failed: {e}")
    return converted
def calculate_rms(segment: AudioSegment) -> float:
    """Return the RMS of the segment's raw sample values (all channels pooled).

    The value is in raw sample units, not normalised to full scale.
    Returns ``0.0`` if the computation fails (the failure is logged).
    """
    try:
        pcm = np.array(segment.get_array_of_samples(), dtype=np.float32)
        mean_square = np.mean(pcm**2)
        return float(np.sqrt(mean_square))
    except Exception as e:
        logger.error(f"calculate_rms failed: {e}")
        return 0.0
def hard_limit(audio_segment: AudioSegment, limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Clip every sample to ``limit_db`` dB relative to full scale.

    The integer dtype and the full-scale value are derived from the segment's
    actual ``sample_width`` (1, 2 or 4 bytes). The previous code only
    distinguished "3 bytes" from "everything else", so 32-bit segments were
    clipped against the 16-bit scale and repacked as int16 under a 4-byte
    width, corrupting the audio.
    NOTE(review): pydub is understood to widen 24-bit PCM to 32-bit when a
    segment is constructed, which is how 4-byte segments arise here — confirm.

    Args:
        audio_segment: Audio to limit; converted to stereo at ``sample_rate``.
        limit_db: Ceiling in dB relative to full scale.
        sample_rate: Frame rate of the returned segment.

    Returns:
        The limited stereo segment. On failure, or for an unsupported sample
        width, the (stereo-converted) input is returned unchanged and the
        problem is logged.
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        width = audio_segment.sample_width
        pcm_dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(width)
        if pcm_dtype is None:
            logger.warning(f"hard_limit: unsupported sample width {width}; audio left unchanged")
            return audio_segment
        full_scale = float(2 ** (8 * width - 1) - 1)
        # Never let the ceiling exceed what the integer type can hold.
        ceiling = min(10 ** (limit_db / 20.0) * full_scale, full_scale)
        # float64 keeps 32-bit sample values exact.
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float64)
        samples = np.clip(samples, -ceiling, ceiling).astype(pcm_dtype)
        if len(samples) % 2 != 0:
            # Keep whole stereo frames only.
            samples = samples[:-1]
        return AudioSegment(
            samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=width,
            channels=2
        )
    except Exception as e:
        logger.error(f"hard_limit failed: {e}")
        return audio_segment
def rms_normalize(segment: AudioSegment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Apply gain so the segment's RMS reaches ``target_rms_db``, then limit peaks.

    The target is computed against the full scale of the segment's actual
    ``sample_width``. The previous code used the 16-bit scale for every width
    other than 3 bytes, which put the target far too low for 32-bit segments.

    Args:
        segment: Audio to normalise; converted to stereo at ``sample_rate``.
        target_rms_db: Desired RMS level in dB relative to full scale.
        peak_limit_db: Ceiling passed to ``hard_limit`` afterwards.
        sample_rate: Frame rate used for the stereo conversion and limiter.

    Returns:
        The normalised, limited segment. On failure the segment is returned in
        the state it had reached and the error is logged.
    """
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        full_scale = float(2 ** (8 * segment.sample_width - 1) - 1)
        target_rms = 10 ** (target_rms_db / 20) * full_scale
        current_rms = calculate_rms(segment)
        if current_rms > 0:
            # Silent (or unmeasurable) audio is left without gain.
            gain_factor = target_rms / current_rms
            segment = segment.apply_gain(20 * np.log10(max(gain_factor, 1e-6)))
        segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate)
        return segment
    except Exception as e:
        logger.error(f"rms_normalize failed: {e}")
        return segment
def balance_stereo(audio_segment: AudioSegment, noise_threshold=-40, sample_rate=48000) -> AudioSegment:
    """Equalise the RMS level of the left and right channels.

    Fixes relative to the previous version:
      * the output integer dtype follows the segment's real ``sample_width``
        (1, 2 or 4 bytes) instead of assuming int16 for anything but 3 bytes;
      * samples are clipped to the integer range before the cast, so a channel
        that is boosted past full scale no longer wraps around.

    NOTE(review): the dB value compared with ``noise_threshold`` is computed
    from raw sample values (not normalised to full scale), so any non-zero
    integer sample is at or above 0 dB; with the default threshold the mask
    only removes samples that are already zero. Left as is — confirm intent.

    Args:
        audio_segment: Audio to balance; converted to stereo at ``sample_rate``.
        noise_threshold: dB threshold for the per-sample mask (see note).
        sample_rate: Frame rate of the returned segment.

    Returns:
        The balanced stereo segment. On failure, or for an unsupported sample
        width, the (stereo-converted) input is returned and the problem logged.
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        if audio_segment.channels != 2:
            return audio_segment
        width = audio_segment.sample_width
        pcm_dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(width)
        if pcm_dtype is None:
            logger.warning(f"balance_stereo: unsupported sample width {width}; audio left unchanged")
            return audio_segment
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float64)
        stereo = samples.reshape(-1, 2)
        db = 20 * np.log10(np.abs(stereo) + 1e-10)
        stereo = stereo * (db > noise_threshold)
        left = stereo[:, 0]
        right = stereo[:, 1]
        # RMS over the non-zero samples of each channel only.
        l_rms = np.sqrt(np.mean(left[left != 0] ** 2)) if np.any(left != 0) else 0
        r_rms = np.sqrt(np.mean(right[right != 0] ** 2)) if np.any(right != 0) else 0
        if l_rms > 0 and r_rms > 0:
            avg = (l_rms + r_rms) / 2
            stereo[:, 0] *= (avg / l_rms)
            stereo[:, 1] *= (avg / r_rms)
        limits = np.iinfo(pcm_dtype)
        out = np.clip(stereo, limits.min, limits.max).flatten().astype(pcm_dtype)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(
            out.tobytes(),
            frame_rate=sample_rate,
            sample_width=width,
            channels=2
        )
    except Exception as e:
        logger.error(f"balance_stereo failed: {e}")
        return audio_segment
def apply_noise_gate(audio_segment: AudioSegment, threshold_db=-80, sample_rate=48000) -> AudioSegment:
    """Zero every sample whose level is at or below ``threshold_db``.

    Fixes relative to the previous version:
      * the output integer dtype follows the segment's real ``sample_width``
        (1, 2 or 4 bytes) instead of assuming int16 for anything but 3 bytes;
      * the mask is applied once — the second identical pass could not change
        the result, because samples are either kept untouched or set to zero.

    NOTE(review): the dB value is computed from raw sample values (not
    normalised to full scale), so any non-zero integer sample is at or above
    0 dB; with the default threshold nothing audible is gated. Left as is —
    confirm intent.

    Args:
        audio_segment: Audio to gate; converted to stereo at ``sample_rate``.
        threshold_db: dB threshold for the per-sample mask (see note).
        sample_rate: Frame rate of the returned segment.

    Returns:
        The gated stereo segment. On failure, or for an unsupported sample
        width, the (stereo-converted) input is returned and the problem logged.
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        if audio_segment.channels != 2:
            return audio_segment
        width = audio_segment.sample_width
        pcm_dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(width)
        if pcm_dtype is None:
            logger.warning(f"apply_noise_gate: unsupported sample width {width}; audio left unchanged")
            return audio_segment
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float64)
        stereo = samples.reshape(-1, 2)
        db = 20 * np.log10(np.abs(stereo) + 1e-10)
        stereo = stereo * (db > threshold_db)
        out = stereo.flatten().astype(pcm_dtype)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(
            out.tobytes(),
            frame_rate=sample_rate,
            sample_width=width,
            channels=2
        )
    except Exception as e:
        logger.error(f"apply_noise_gate failed: {e}")
        return audio_segment
def apply_eq(segment: AudioSegment, sample_rate=48000) -> AudioSegment:
    """Band-limit the segment (20 Hz high-pass, 8 kHz low-pass) and attenuate it.

    The attenuation is applied as three successive gain cuts of 3, 3 and
    10 dB. On failure the segment is returned in the state it had reached and
    the error is logged.
    """
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        segment = segment.high_pass_filter(20)
        segment = segment.low_pass_filter(8000)
        # Three separate cuts, applied one after another.
        for cut_db in (3, 3, 10):
            segment = segment - cut_db
        return segment
    except Exception as e:
        logger.error(f"apply_eq failed: {e}")
        return segment
def apply_fade(segment: AudioSegment, fade_in_duration=500, fade_out_duration=500) -> AudioSegment:
    """Fade the segment in and out; durations are passed straight to pydub.

    The segment is made stereo at its own frame rate first. If fading fails
    the error is logged and the un-faded stereo segment is returned.
    """
    try:
        segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width)
        with_fade_in = segment.fade_in(fade_in_duration)
        segment = with_fade_in.fade_out(fade_out_duration)
        return segment
    except Exception as e:
        logger.error(f"apply_fade failed: {e}")
        return segment
# ======================================================================================
# PROMPTS
# ======================================================================================
def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num):
    """Build the Red Hot Chili Peppers style prompt for one chunk.

    A ``bpm`` of exactly 120 is treated as "not chosen" and replaced by a
    random tempo in 90-130. A style argument equal to ``"none"`` falls back to
    a built-in description (the synthesizer is simply omitted).
    ``rhythmic_steps`` is accepted but not used. Chunk 1 gets an
    intro/verse ending; every other chunk gets a chorus/outro ending.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(90, 130)
        drum = ", standard rock drums with funk fills" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = ", funky slap bass" if bass_style == "none" else f", {bass_style} bass"
        guitar = ", energetic guitar riffs" if guitar_style == "none" else f", {guitar_style} guitar"
        base = f"Instrumental alternative rock by Red Hot Chili Peppers{guitar}{bass}{drum}{synth}, funk-rock energy at {bpm} BPM"
        ending = ", dynamic intro and expressive verse." if chunk_num == 1 else ", powerful chorus and energetic outro."
        return base + ending
    except Exception as e:
        logger.error(f"set_red_hot_chili_peppers_prompt failed: {e}")
        return ""
def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Nirvana grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 100-130. Bass,
    guitar and rhythm equal to ``"none"`` are each picked at random from two
    built-in options; drums fall back to standard rock drums and the
    synthesizer is omitted.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(100, 130)
        drum = ", standard rock drums, punk energy" if drum_beat == "none" else f", {drum_beat} drums, punk energy"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        # Random picks happen in a fixed order: bass, guitar, rhythm.
        chosen_bass = random.choice(['deep bass', 'melodic bass']) if bass_style == "none" else bass_style
        chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
        chosen_rhythm = random.choice(['steady steps', 'dynamic shifts']) if rhythmic_steps == "none" else rhythmic_steps
        return (
            f"Instrumental grunge by Nirvana, {chosen_guitar}, {chosen_bass}{drum}{synth}"
            f", raw lo-fi production, {chosen_rhythm} at {bpm} BPM."
        )
    except Exception as e:
        logger.error(f"set_nirvana_grunge_prompt failed: {e}")
        return ""
def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Pearl Jam grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 100-140. Guitar
    and rhythm equal to ``"none"`` are each picked at random from two built-in
    options; drums and bass fall back to fixed descriptions and the
    synthesizer is omitted.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(100, 140)
        drum = ", standard rock drums, driving rhythm" if drum_beat == "none" else f", {drum_beat} drums, driving rhythm"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = ", melodic bass, emotional tone" if bass_style == "none" else f", {bass_style}, emotional tone"
        # Random picks happen in a fixed order: guitar, then rhythm.
        chosen_guitar = random.choice(['clean guitar', 'distorted guitar']) if guitar_style == "none" else guitar_style
        chosen_rhythm = random.choice(['steady steps', 'syncopated steps']) if rhythmic_steps == "none" else rhythmic_steps
        return (
            f"Instrumental grunge by Pearl Jam, {chosen_guitar}, soulful leads{bass}{drum}{synth}"
            f", classic rock influences, {chosen_rhythm} at {bpm} BPM."
        )
    except Exception as e:
        logger.error(f"set_pearl_jam_grunge_prompt failed: {e}")
        return ""
def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Soundgarden grunge/metal prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 90-140. A style
    argument equal to ``"none"`` falls back to a fixed description (the
    synthesizer is simply omitted).

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(90, 140)
        drum = ", standard rock drums, heavy rhythm" if drum_beat == "none" else f", {drum_beat} drums, heavy rhythm"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = ", deep bass, sludgy tone" if bass_style == "none" else f", {bass_style}, sludgy tone"
        if guitar_style == "none":
            guitar = ", distorted guitar, downtuned riffs, psychedelic vibe"
        else:
            guitar = f", {guitar_style}, downtuned riffs, psychedelic vibe"
        rhythm = ", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}"
        return f"Instrumental grunge with heavy metal influences by Soundgarden{guitar}{bass}{drum}{synth}{rhythm} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_soundgarden_grunge_prompt failed: {e}")
        return ""
def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Foo Fighters post-grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 110-150. Guitar
    and rhythm equal to ``"none"`` are each picked at random from two built-in
    options; drums and bass fall back to fixed descriptions and the
    synthesizer is omitted.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(110, 150)
        drum = ", standard rock drums, powerful drive" if drum_beat == "none" else f", {drum_beat} drums, powerful drive"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = ", melodic bass, supportive tone" if bass_style == "none" else f", {bass_style}, supportive tone"
        # Random picks happen in a fixed order: guitar, then rhythm.
        chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
        chosen_rhythm = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps
        return (
            f"Instrumental alternative rock with post-grunge influences by Foo Fighters, {chosen_guitar}, anthemic quality"
            f", stadium-ready hooks{bass}{drum}{synth}, {chosen_rhythm} at {bpm} BPM."
        )
    except Exception as e:
        logger.error(f"set_foo_fighters_prompt failed: {e}")
        return ""
def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the prompt behind the "classic rock" setter.

    Despite the function name, the generated text describes thrash metal by
    Metallica. A ``bpm`` of exactly 120 is replaced by a random tempo in
    120-180. A style argument equal to ``"none"`` falls back to a fixed
    description (the synthesizer is simply omitted).

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        if bpm == 120:
            bpm = random.randint(120, 180)
        drum = ", double bass drums" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = ", aggressive bass" if bass_style == "none" else f", {bass_style}"
        if guitar_style == "none":
            guitar = ", distorted guitar, blazing fast riffs"
        else:
            guitar = f", {guitar_style}, blazing fast riffs"
        rhythm = ", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}"
        return f"Instrumental thrash metal by Metallica{guitar}{bass}{drum}{synth}{rhythm} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_classic_rock_prompt failed: {e}")
        return ""
def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Smashing Pumpkins alternative-rock prompt.

    Guitar and synthesizer equal to ``"none"`` fall back to "dreamy guitar"
    and "lush synths"; drums and bass equal to ``"none"`` are omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        guitar = ", dreamy guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        synth = ", lush synths" if synthesizer == "none" else f", {synthesizer}"
        drum = "" if drum_beat == "none" else f", {drum_beat} drums"
        bass = "" if bass_style == "none" else f", {bass_style} bass"
        return f"Instrumental alternative rock by Smashing Pumpkins{guitar}{synth}{drum}{bass} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_smashing_pumpkins_prompt failed: {e}")
        return ""
def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Radiohead experimental-rock prompt.

    Synthesizer and bass equal to ``"none"`` fall back to "atmospheric synths"
    and "hypnotic bass"; drums and guitar equal to ``"none"`` are omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        synth = ", atmospheric synths" if synthesizer == "none" else f", {synthesizer}"
        bass = ", hypnotic bass" if bass_style == "none" else f", {bass_style} bass"
        drum = "" if drum_beat == "none" else f", {drum_beat} drums"
        guitar = "" if guitar_style == "none" else f", {guitar_style} guitar"
        return f"Instrumental experimental rock by Radiohead{synth}{bass}{drum}{guitar} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_radiohead_prompt failed: {e}")
        return ""
def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Pixies alternative-rock prompt.

    Guitar and bass equal to ``"none"`` fall back to "distorted guitar" and
    "melodic bass"; drums and synthesizer equal to ``"none"`` are omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        guitar = ", distorted guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", melodic bass" if bass_style == "none" else f", {bass_style} bass"
        drum = "" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return f"Instrumental alternative rock by Pixies{guitar}{bass}{drum}{synth} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_alternative_rock_prompt failed: {e}")
        return ""
def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Joy Division post-punk prompt.

    Guitar, bass and drums equal to ``"none"`` fall back to "jangly guitar",
    "driving bass" and "precise drums"; a ``"none"`` synthesizer is omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        guitar = ", jangly guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", driving bass" if bass_style == "none" else f", {bass_style} bass"
        drum = ", precise drums" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return f"Instrumental post-punk by Joy Division{guitar}{bass}{drum}{synth} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_post_punk_prompt failed: {e}")
        return ""
def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Arctic Monkeys indie-rock prompt.

    Guitar and bass equal to ``"none"`` fall back to "jangly guitar" and
    "groovy bass"; drums and synthesizer equal to ``"none"`` are omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        guitar = ", jangly guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", groovy bass" if bass_style == "none" else f", {bass_style} bass"
        drum = "" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return f"Instrumental indie rock by Arctic Monkeys{guitar}{bass}{drum}{synth} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_indie_rock_prompt failed: {e}")
        return ""
def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Rage Against the Machine funk-rock prompt.

    Guitar, bass and drums equal to ``"none"`` fall back to "funky guitar",
    "slap bass" and "heavy drums"; a ``"none"`` synthesizer is omitted.
    ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        guitar = ", funky guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", slap bass" if bass_style == "none" else f", {bass_style} bass"
        drum = ", heavy drums" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return f"Instrumental funk rock by Rage Against the Machine{guitar}{bass}{drum}{synth} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_funk_rock_prompt failed: {e}")
        return ""
def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Juan Atkins Detroit-techno prompt.

    Synthesizer, bass and drums equal to ``"none"`` fall back to "pulsing
    synths", "driving bass" and "four-on-the-floor drums"; a ``"none"`` guitar
    is omitted. ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        synth = ", pulsing synths" if synthesizer == "none" else f", {synthesizer}"
        bass = ", driving bass" if bass_style == "none" else f", {bass_style} bass"
        drum = ", four-on-the-floor drums" if drum_beat == "none" else f", {drum_beat} drums"
        guitar = "" if guitar_style == "none" else f", {guitar_style} guitar"
        return f"Instrumental Detroit techno by Juan Atkins{synth}{bass}{drum}{guitar} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_detroit_techno_prompt failed: {e}")
        return ""
def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Larry Heard deep-house prompt.

    Synthesizer, bass and drums equal to ``"none"`` fall back to "warm
    synths", "deep bass" and "steady kick drums"; a ``"none"`` guitar is
    omitted. ``rhythmic_steps`` is accepted but not used.

    Returns:
        The prompt text, or ``""`` if building it fails. The failure is now
        logged instead of being swallowed silently.
    """
    try:
        synth = ", warm synths" if synthesizer == "none" else f", {synthesizer}"
        bass = ", deep bass" if bass_style == "none" else f", {bass_style} bass"
        drum = ", steady kick drums" if drum_beat == "none" else f", {drum_beat} drums"
        guitar = "" if guitar_style == "none" else f", {guitar_style} guitar"
        return f"Instrumental deep house by Larry Heard{synth}{bass}{drum}{guitar} at {bpm} BPM."
    except Exception as e:
        logger.error(f"set_deep_house_prompt failed: {e}")
        return ""
# Named sampling presets. generate_music replaces the caller's
# cfg_scale/top_k/top_p/temperature with these values whenever the selected
# preset is not "default".
PRESETS = {
    "default": dict(cfg_scale=5.8, top_k=250, top_p=0.95, temperature=0.90),
    "rock": dict(cfg_scale=5.8, top_k=250, top_p=0.95, temperature=0.90),
    "techno": dict(cfg_scale=5.2, top_k=300, top_p=0.96, temperature=0.95),
    "grunge": dict(cfg_scale=6.2, top_k=220, top_p=0.94, temperature=0.90),
    "indie": dict(cfg_scale=5.5, top_k=240, top_p=0.95, temperature=0.92),
    "funk_rock": dict(cfg_scale=5.8, top_k=260, top_p=0.96, temperature=0.94),
}
# ======================================================================================
# MODEL LOAD
# ======================================================================================
# audiocraft is imported here rather than with the other imports so that a
# failure can be logged with an install hint before the error is re-raised.
try:
    from audiocraft.models import MusicGen
except Exception as e:
    # Any import-time failure is logged with the same hint, then propagated.
    logger.error("audiocraft is required. pip install audiocraft")
    raise
def load_model():
    """Load MusicGen from ``./models/musicgen-large`` onto ``DEVICE``.

    Warns when less than 5000 MiB of VRAM is reported free, frees cached
    memory, and exits the process with status 1 if the model directory does
    not exist. The returned model has a 30-second default duration.
    """
    vram_free_mb = check_vram()
    vram_is_low = vram_free_mb is not None and vram_free_mb < 5000
    if vram_is_low:
        logger.warning("Low free VRAM; consider closing other apps.")
    clean_memory()

    model_dir = "./models/musicgen-large"
    if not os.path.exists(model_dir):
        logger.error(f"Model path missing: {model_dir}")
        sys.exit(1)

    logger.info("Loading MusicGen (large)...")
    with autocast(dtype=torch.float16):
        loaded = MusicGen.get_pretrained(model_dir, device=DEVICE)
    # Baseline only: generate_music sets its own parameters on every call.
    loaded.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen loaded.")
    return loaded
# Shared model instance, loaded once at import time and used by generate_music.
musicgen_model = load_model()
# ======================================================================================
# GENERATION PIPELINE (30s CHUNKING, SEAMLESS MERGE)
# ======================================================================================
def get_latest_log() -> str:
    """Return the text of the most recently modified ``musicgen_log_*.log`` file.

    Returns a short message instead when ``LOG_DIR`` holds no matching file or
    when reading fails.
    """
    try:
        candidates = list(Path(LOG_DIR).glob("musicgen_log_*.log"))
        if not candidates:
            return "No log files found."
        newest = max(candidates, key=os.path.getmtime)
        return newest.read_text()
    except Exception as e:
        return f"Error reading log: {e}"
def set_bitrate_128():
    """Return the 128 kbps bitrate option string."""
    return "128k"
def set_bitrate_192():
    """Return the 192 kbps bitrate option string."""
    return "192k"
def set_bitrate_320():
    """Return the 320 kbps bitrate option string."""
    return "320k"
def set_sample_rate_22050():
    """Return the 22.05 kHz sample-rate option string."""
    return "22050"
def set_sample_rate_44100():
    """Return the 44.1 kHz sample-rate option string."""
    return "44100"
def set_sample_rate_48000():
    """Return the 48 kHz sample-rate option string."""
    return "48000"
def set_bit_depth_16():
    """Return the 16-bit depth option string."""
    return "16"
def set_bit_depth_24():
    """Return the 24-bit depth option string."""
    return "24"
def generate_music_wrapper(*args):
    """Forward all positional arguments to ``generate_music`` and always free memory.

    ``clean_memory`` runs whether the call returns or raises.
    """
    try:
        outcome = generate_music(*args)
    finally:
        clean_memory()
    return outcome
def _export_torch_to_segment(audio_tensor: torch.Tensor, sample_rate: int, bit_depth_int: int) -> Optional[AudioSegment]:
    """Convert a torch audio tensor to a pydub segment via a temporary WAV file.

    Changes relative to the previous version:
      * the temp file gets a unique name from ``tempfile.mkstemp`` instead of
        a millisecond timestamp, which could collide between calls;
      * the memory map that was opened but never read has been removed.

    Args:
        audio_tensor: Audio passed to ``torchaudio.save``.
        sample_rate: Sample rate written to the WAV header.
        bit_depth_int: ``bits_per_sample`` used for the WAV file.

    Returns:
        The loaded segment, or ``None`` if saving or loading fails (the error
        and its traceback are logged). The temp file is removed either way.
    """
    import tempfile  # local import: the module is not imported at file level

    temp = None
    try:
        # Created in the working directory, like the original temp file.
        fd, temp = tempfile.mkstemp(prefix="temp_audio_", suffix=".wav", dir=".")
        os.close(fd)  # only the unique path is needed; torchaudio reopens it
        torchaudio.save(temp, audio_tensor, sample_rate, bits_per_sample=bit_depth_int)
        return AudioSegment.from_wav(temp)
    except Exception as e:
        logger.error(f"_export_torch_to_segment failed: {e}")
        logger.error(traceback.format_exc())
        return None
    finally:
        if temp is not None:
            try:
                os.remove(temp)
            except OSError:
                pass
def _crossfade_segments(seg_a: AudioSegment, seg_b: AudioSegment, overlap_ms: int, sample_rate: int, bit_depth_int: int) -> AudioSegment:
    """Join two segments, blending the tail of ``seg_a`` into the head of ``seg_b``.

    Fixes relative to the previous version:
      * The fade curves were a full Hann window and its mirror image. A
        symmetric Hann window equals its own mirror image, so both signals
        were multiplied by the same bell curve and the overlap dropped to
        silence at both seams. The fades are now complementary raised-cosine
        ramps (fade-in 0 -> 1, fade-out 1 -> 0) that sum to one.
      * The blended audio is saved as float, as ``_export_torch_to_segment``
        does, instead of as an int32 tensor scaled by 2**23 for 24-bit output.
        NOTE(review): the int32/24-bit path appeared to be written at the
        wrong level — confirm against torchaudio.save.
      * All temp files live in one temporary directory that is always removed;
        previously the blended file leaked when an error occurred.

    Args:
        seg_a: Leading segment.
        seg_b: Trailing segment.
        overlap_ms: Length of the blended region in milliseconds.
        sample_rate: Rate both segments are converted to.
        bit_depth_int: ``bits_per_sample`` for the blended WAV file.

    Returns:
        ``seg_a`` without its tail + the blended overlap + ``seg_b`` without
        its head. Falls back to plain concatenation when the overlap is not
        usable or any step fails (failures are logged).
    """
    import tempfile  # local import: the module is not imported at file level

    try:
        seg_a = ensure_stereo(seg_a, sample_rate, seg_a.sample_width)
        seg_b = ensure_stereo(seg_b, sample_rate, seg_b.sample_width)
        if overlap_ms <= 0 or len(seg_a) < overlap_ms or len(seg_b) < overlap_ms:
            return seg_a + seg_b
        with tempfile.TemporaryDirectory(prefix="crossfade_") as tmp_dir:
            prev_wav = os.path.join(tmp_dir, "prev.wav")
            curr_wav = os.path.join(tmp_dir, "curr.wav")
            blend_wav = os.path.join(tmp_dir, "blend.wav")
            # export() returns the open output file; close it so the directory
            # can be removed on every platform.
            seg_a[-overlap_ms:].export(prev_wav, format="wav").close()
            seg_b[:overlap_ms].export(curr_wav, format="wav").close()
            a_audio, sr_a = torchaudio.load(prev_wav)
            b_audio, sr_b = torchaudio.load(curr_wav)
            if sr_a != sample_rate:
                a_audio = torchaudio.functional.resample(a_audio, sr_a, sample_rate, lowpass_filter_width=64)
            if sr_b != sample_rate:
                b_audio = torchaudio.functional.resample(b_audio, sr_b, sample_rate, lowpass_filter_width=64)
            n = min(a_audio.shape[1], b_audio.shape[1])
            n -= n % 2
            if n <= 0:
                return seg_a + seg_b
            # Complementary raised-cosine ramps: fade_in + fade_out == 1.
            ramp = torch.linspace(0.0, 1.0, n)
            fade_in = 0.5 - 0.5 * torch.cos(math.pi * ramp)
            fade_out = 1.0 - fade_in
            blended = (a_audio[:, :n] * fade_out + b_audio[:, :n] * fade_in).to(torch.float32)
            blended = torch.clamp(blended, -1.0, 1.0)
            torchaudio.save(blend_wav, blended, sample_rate, bits_per_sample=bit_depth_int)
            blended_seg = AudioSegment.from_wav(blend_wav)
        blended_seg = ensure_stereo(blended_seg, sample_rate, blended_seg.sample_width)
        # combine
        return seg_a[:-overlap_ms] + blended_seg + seg_b[overlap_ms:]
    except Exception as e:
        logger.error(f"_crossfade_segments failed: {e}")
        return seg_a + seg_b
def generate_music(
    instrumental_prompt: str,
    cfg_scale: float,
    top_k: int,
    top_p: float,
    temperature: float,
    total_duration: int,
    bpm: int,
    drum_beat: str,
    synthesizer: str,
    rhythmic_steps: str,
    bass_style: str,
    guitar_style: str,
    target_volume: float,
    preset: str,
    max_steps: str,  # kept for UI parity
    vram_status_text: str,
    bitrate: str,
    output_sample_rate: str,
    bit_depth: str
) -> Tuple[Optional[str], str, str]:
    """Generate an instrumental track in 30-second chunks and export it as MP3.

    Pipeline: apply the preset, validate the export options, generate each
    chunk with the shared ``musicgen_model`` (every chunk after the first
    continues from the last 200 ms of the previous one), post-process each
    chunk, crossfade the chunks together, polish the result and write an MP3
    into the working directory.

    Changes relative to the previous version: the two bare ``except:``
    clauses around the integer conversions now catch only ``TypeError`` and
    ``ValueError``, and an unused local was removed. Everything else is
    unchanged.

    Args:
        instrumental_prompt: Text prompt; must not be blank. If it contains
            "Red Hot Chili Peppers", a per-chunk prompt is rebuilt from the
            style arguments instead of using this text as is.
        cfg_scale, top_k, top_p, temperature: Sampling parameters; replaced
            by the preset's values when ``preset`` is not "default".
        total_duration: Requested length in seconds, clamped to 30-120.
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style:
            Only used for the Red Hot Chili Peppers per-chunk prompt.
        target_volume: Target RMS level in dB passed to ``rms_normalize``.
        preset: Key into ``PRESETS``; unknown names fall back to "default".
        max_steps: Not used.
        vram_status_text: Status text returned unchanged on early exits.
        bitrate: MP3 bitrate; 128k is tried as a fallback if export fails.
        output_sample_rate: Final frame rate, as a numeric string.
        bit_depth: Intermediate WAV bit depth, as a numeric string.

    Returns:
        ``(mp3_path, status_message, vram_status_text)``; ``mp3_path`` is
        ``None`` on any failure. Errors are reported through the status
        message rather than raised.
    """
    global musicgen_model
    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!", vram_status_text
    try:
        # Apply preset if not default
        if preset != "default":
            p = PRESETS.get(preset, PRESETS["default"])
            cfg_scale, top_k, top_p, temperature = p["cfg_scale"], p["top_k"], p["top_p"], p["temperature"]
            logger.info(f"Preset '{preset}' applied: cfg={cfg_scale} top_k={top_k} top_p={top_p} temp={temperature}")
        # Validate numerics
        try:
            output_sr_int = int(output_sample_rate)
        except (TypeError, ValueError):
            return None, "❌ Invalid output sampling rate; choose 22050/44100/48000", vram_status_text
        try:
            bit_depth_int = int(bit_depth)
            sample_width = 3 if bit_depth_int == 24 else 2
        except (TypeError, ValueError):
            return None, "❌ Invalid bit depth; choose 16 or 24", vram_status_text
        if not check_disk_space():
            return None, "⚠️ Low disk space (<1GB).", vram_status_text
        # Chunking: EXACT 30s per chunk (unify stepping -> always 30s). Two chunks => full 60s song.
        CHUNK_SEC = 30
        total_duration = max(30, min(int(total_duration), 120))
        num_chunks = math.ceil(total_duration / CHUNK_SEC)
        # Internal processing rate (resample to this for DSP)
        PROCESS_SR = 48000
        OVERLAP_SEC = 0.20  # 200ms crossfade/prompt tail
        # Seed & params: one fresh random seed per call, applied to every RNG in use.
        seed = random.randint(0, 2**31 - 1)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        musicgen_model.set_generation_params(
            duration=CHUNK_SEC,
            use_sampling=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            cfg_coef=float(cfg_scale),
            two_step_cfg=False,
        )
        vram_status_text = f"Start VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        segments: List[AudioSegment] = []
        start_time = time.time()
        for idx in range(num_chunks):
            chunk_idx = idx + 1
            # Every chunk is generated at CHUNK_SEC; `dur` is only the length
            # the chunk is trimmed to afterwards (shorter for a partial last chunk).
            dur = CHUNK_SEC if (idx < num_chunks - 1) else (total_duration - CHUNK_SEC * (num_chunks - 1) or CHUNK_SEC)
            logger.info(f"Generating chunk {chunk_idx}/{num_chunks} ({dur}s)")
            # Prompt per chunk (variable for RHCP only)
            if "Red Hot Chili Peppers" in instrumental_prompt:
                prompt_text = set_red_hot_chili_peppers_prompt(
                    bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_idx
                )
            else:
                prompt_text = instrumental_prompt
            try:
                with torch.no_grad():
                    with autocast(dtype=torch.float16):
                        clean_memory()
                        if idx == 0:
                            audio = musicgen_model.generate([prompt_text], progress=True)[0].cpu()
                        else:
                            # Use tail of previous segment as continuation prompt
                            prev_seg = segments[-1]
                            prev_seg = apply_noise_gate(prev_seg, threshold_db=-80, sample_rate=PROCESS_SR)
                            prev_seg = balance_stereo(prev_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                            temp_prev = f"prev_{int(time.time()*1000)}.wav"
                            try:
                                prev_seg.export(temp_prev, format="wav")
                                prev_audio, prev_sr = torchaudio.load(temp_prev)
                                if prev_sr != PROCESS_SR:
                                    prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, PROCESS_SR, lowpass_filter_width=64)
                                if prev_audio.shape[0] != 2:
                                    prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]]
                                prev_audio = prev_audio.to(DEVICE)
                                tail = prev_audio[:, -int(PROCESS_SR * OVERLAP_SEC):]
                                audio = musicgen_model.generate_continuation(
                                    prompt=tail,
                                    prompt_sample_rate=PROCESS_SR,
                                    descriptions=[prompt_text],
                                    progress=True
                                )[0].cpu()
                                del prev_audio, tail
                            finally:
                                try:
                                    if os.path.exists(temp_prev):
                                        os.remove(temp_prev)
                                except OSError:
                                    pass
                clean_memory()
            except Exception as e:
                logger.error(f"Chunk {chunk_idx} generation failed: {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to generate chunk {chunk_idx}: {e}", vram_status_text
            try:
                # Ensure stereo & resample to PROCESS_SR for DSP
                if audio.shape[0] != 2:
                    audio = audio.repeat(2, 1)[:, :audio.shape[1]]
                audio = audio.to(dtype=torch.float32)
                # NOTE(review): 32000 is assumed to be the model's output rate — confirm.
                audio = torchaudio.functional.resample(audio, 32000, PROCESS_SR, lowpass_filter_width=64)
                seg = _export_torch_to_segment(audio, PROCESS_SR, bit_depth_int)
                if seg is None:
                    return None, f"❌ Failed to convert audio for chunk {chunk_idx}", vram_status_text
                seg = ensure_stereo(seg, PROCESS_SR, sample_width)
                seg = seg - 15
                seg = apply_noise_gate(seg, threshold_db=-80, sample_rate=PROCESS_SR)
                seg = balance_stereo(seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                seg = rms_normalize(seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
                seg = apply_eq(seg, sample_rate=PROCESS_SR)
                # Trim exactly to 'dur' seconds for last chunk
                seg = seg[:dur * 1000]
                segments.append(seg)
                del audio
                clean_memory()
                vram_status_text = f"VRAM after chunk {chunk_idx}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            except Exception as e:
                logger.error(f"Post-processing failed (chunk {chunk_idx}): {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to process chunk {chunk_idx}: {e}", vram_status_text
        if not segments:
            return None, "❌ No audio generated.", vram_status_text
        # Seamless join with crossfades. Each join overlaps the two segments by
        # overlap_ms, so the joined audio is that much shorter than the sum.
        logger.info("Combining chunks...")
        final_seg = segments[0]
        overlap_ms = int(OVERLAP_SEC * 1000)
        for i in range(1, len(segments)):
            final_seg = _crossfade_segments(final_seg, segments[i], overlap_ms, PROCESS_SR, bit_depth_int)
        # Final length clamp
        final_seg = final_seg[:total_duration * 1000]
        # Final polish
        final_seg = apply_noise_gate(final_seg, threshold_db=-80, sample_rate=PROCESS_SR)
        final_seg = balance_stereo(final_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
        final_seg = rms_normalize(final_seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
        final_seg = apply_eq(final_seg, sample_rate=PROCESS_SR)
        final_seg = apply_fade(final_seg, 500, 800)
        final_seg = final_seg - 10
        final_seg = final_seg.set_frame_rate(output_sr_int)
        # Export MP3
        mp3_path = f"ghostai_music_{int(time.time())}.mp3"
        try:
            clean_memory()
            final_seg.export(mp3_path, format="mp3", bitrate=bitrate, tags={"title": "GhostAI Instrumental", "artist": "GhostAI"})
        except Exception as e:
            # Retry once at 128k under a different file name before giving up.
            logger.error(f"MP3 export failed ({bitrate}): {e}")
            fb = f"ghostai_music_fallback_{int(time.time())}.mp3"
            try:
                final_seg.export(fb, format="mp3", bitrate="128k")
                mp3_path = fb
            except Exception as ee:
                return None, f"❌ Failed to export MP3: {ee}", vram_status_text
        elapsed = time.time() - start_time
        vram_status_text = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        logger.info(f"Done in {elapsed:.2f}s -> {mp3_path}")
        return mp3_path, "✅ Done! 30s chunking unified seamlessly. Check output loudness/quality.", vram_status_text
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        logger.error(traceback.format_exc())
        return None, f"❌ Generation failed: {e}", vram_status_text
    finally:
        clean_memory()
def clear_inputs():
    """Return the default value of every input as an 18-item tuple.

    The values come from ``DEFAULT_SETTINGS`` in the fixed key order listed
    below, which starts with the prompt.
    """
    defaults = DEFAULT_SETTINGS.copy()
    field_order = (
        "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
        "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
        "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
        "bitrate", "output_sample_rate", "bit_depth",
    )
    return tuple(defaults[name] for name in field_order)
# ======================================================================================
# SERVER STATUS (BUSY/IDLE) & RENDER API
# ======================================================================================
# Lock taken by set_busy / is_busy / job_elapsed around the state below.
BUSY_LOCK = threading.Lock()
# True while a job is marked as running.
BUSY_FLAG = False
# Marker file written while busy and removed when idle; it holds the job id.
BUSY_FILE = "/tmp/musicgen_busy.lock"
# Id and start time (time.time() value) of the current job; both None when idle.
CURRENT_JOB: Dict[str, Any] = {"id": None, "start": None}
def set_busy(val: bool, job_id: Optional[str] = None):
    """Mark the server busy or idle and mirror that state to ``BUSY_FILE``.

    Args:
        val: Truthy to record a running job, falsy to mark the server idle.
        job_id: Identifier stored for the running job. When omitted or empty,
            ``job_<unix-seconds>`` is generated. Ignored when ``val`` is falsy.

    The lock file is best-effort: a failure to write or remove it is logged
    as a warning and never raised, so the in-memory state is always updated.
    """
    global BUSY_FLAG
    with BUSY_LOCK:
        BUSY_FLAG = val
        if val:
            CURRENT_JOB["id"] = job_id or f"job_{int(time.time())}"
            CURRENT_JOB["start"] = time.time()
            try:
                Path(BUSY_FILE).write_text(CURRENT_JOB["id"])
            except Exception as exc:
                # Keep going: the in-memory flag is what is_busy() reports.
                logger.warning("Could not write busy lock file %s: %s", BUSY_FILE, exc)
        else:
            CURRENT_JOB["id"] = None
            CURRENT_JOB["start"] = None
            try:
                # Remove directly instead of exists()-then-remove, which can
                # race with another process deleting the file in between.
                os.remove(BUSY_FILE)
            except FileNotFoundError:
                pass  # Already absent: nothing to clean up.
            except Exception as exc:
                logger.warning("Could not remove busy lock file %s: %s", BUSY_FILE, exc)
def is_busy() -> bool:
    """Return the current value of ``BUSY_FLAG``, read under ``BUSY_LOCK``."""
    with BUSY_LOCK:
        flag = BUSY_FLAG
    return flag
def job_elapsed() -> float:
    """Return seconds since the recorded job start, or 0.0 when none is set."""
    with BUSY_LOCK:
        started_at = CURRENT_JOB["start"]
        return 0.0 if started_at is None else time.time() - started_at
# Request body for POST /render.
# Only `instrumental_prompt` is required; every other field defaults to None.
# The /render handler copies only non-None fields over the saved settings, so
# an omitted field leaves the corresponding saved value in effect.
class RenderRequest(BaseModel):
instrumental_prompt: str
cfg_scale: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
# Song length in seconds (the UI offers 30/60/90/120).
total_duration: Optional[int] = None
# Tempo in beats per minute.
bpm: Optional[int] = None
drum_beat: Optional[str] = None
synthesizer: Optional[str] = None
rhythmic_steps: Optional[str] = None
bass_style: Optional[str] = None
guitar_style: Optional[str] = None
# Target loudness; the UI labels this "dBFS RMS".
target_volume: Optional[float] = None
preset: Optional[str] = None
max_steps: Optional[int] = None
# Export options are carried as strings (e.g. "192k", "48000", "16" in the UI).
bitrate: Optional[str] = None
output_sample_rate: Optional[str] = None
bit_depth: Optional[str] = None
# Request body for POST /settings: a free-form mapping of setting name to
# value. The handler merges it over the current settings without filtering keys.
class SettingsUpdate(BaseModel):
settings: Dict[str, Any]
# HTTP API application that exposes status, settings and render endpoints.
fastapp = FastAPI(title="GhostAI Music Server", version="1.0")

# CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS docs advise against - confirm it is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
fastapp.add_middleware(CORSMiddleware, **_cors_options)
@fastapp.get("/health")
def health():
    # Liveness probe: always reports ok together with the current Unix time
    # truncated to whole seconds.
    now_ts = int(time.time())
    return dict(ok=True, ts=now_ts)
@fastapp.get("/status")
def status():
    """Report whether a job is running, as one consistent snapshot.

    Returns a dict with:
        busy: current busy flag.
        job_id: id of the running job, or None.
        since: start timestamp of the running job (``time.time()`` value), or None.
        elapsed: seconds since ``since``, or 0.0 when no job is recorded.
        lockfile: whether ``BUSY_FILE`` currently exists on disk.
    """
    # Read flag, id and start time under a single lock acquisition so the
    # response cannot mix values from before and after a job starts or ends.
    with BUSY_LOCK:
        busy = BUSY_FLAG
        job_id = CURRENT_JOB["id"]
        started = CURRENT_JOB["start"]
    elapsed = 0.0 if started is None else time.time() - started
    return {
        "busy": busy,
        "job_id": job_id,
        "since": started,
        "elapsed": elapsed,
        "lockfile": os.path.exists(BUSY_FILE),
    }
@fastapp.get("/config")
def get_config():
    # Exposes the live CURRENT_SETTINGS mapping under the "defaults" key.
    response = {}
    response["defaults"] = CURRENT_SETTINGS
    return response
@fastapp.post("/settings")
def set_settings(payload: SettingsUpdate):
    # Merge the posted values over a copy of the current settings, persist the
    # merged result, then publish it into CURRENT_SETTINGS. Any failure along
    # the way is reported to the client as HTTP 400 with the error text.
    try:
        merged = CURRENT_SETTINGS.copy()
        if payload.settings:
            merged.update(payload.settings)
        save_settings_to_file(merged)
        CURRENT_SETTINGS.update(merged)
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err))
    return {"ok": True, "saved": merged}
# Serializes the busy check-and-set in render(). is_busy() and set_busy() each
# take BUSY_LOCK separately, so without this gate two simultaneous requests
# could both observe an idle server and both start generating.
_RENDER_GATE = threading.Lock()


@fastapp.post("/render")
def render(req: RenderRequest):
    """Render one track synchronously and return the path of the MP3.

    Non-None fields of ``req`` override a copy of ``CURRENT_SETTINGS`` for
    this call only; the saved settings are not modified.

    Returns:
        Dict with ``ok``, ``job_id``, ``path`` (MP3 file), ``status`` (message
        from ``generate_music``) and ``vram`` (VRAM status text).

    Raises:
        HTTPException: 409 when another job is running; 500, carrying the
            message from ``generate_music``, when no file was produced.
    """
    # Non-blocking acquire: a second request arriving while another /render is
    # in flight is rejected immediately instead of queueing behind it.
    if not _RENDER_GATE.acquire(blocking=False):
        raise HTTPException(status_code=409, detail="Server busy")
    try:
        # Still consult the shared flag: it may have been set outside this
        # endpoint.
        if is_busy():
            raise HTTPException(status_code=409, detail="Server busy")
        job_id = f"render_{int(time.time())}"
        set_busy(True, job_id)
        try:
            s = CURRENT_SETTINGS.copy()
            # Apply per-request overrides; omitted (None) fields keep the saved value.
            s.update({k: v for k, v in req.dict().items() if v is not None})
            mp3, msg, vram = generate_music(
                s.get("instrumental_prompt", req.instrumental_prompt),
                float(s.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])),
                int(s.get("top_k", DEFAULT_SETTINGS["top_k"])),
                float(s.get("top_p", DEFAULT_SETTINGS["top_p"])),
                float(s.get("temperature", DEFAULT_SETTINGS["temperature"])),
                int(s.get("total_duration", DEFAULT_SETTINGS["total_duration"])),
                int(s.get("bpm", DEFAULT_SETTINGS["bpm"])),
                str(s.get("drum_beat", DEFAULT_SETTINGS["drum_beat"])),
                str(s.get("synthesizer", DEFAULT_SETTINGS["synthesizer"])),
                str(s.get("rhythmic_steps", DEFAULT_SETTINGS["rhythmic_steps"])),
                str(s.get("bass_style", DEFAULT_SETTINGS["bass_style"])),
                str(s.get("guitar_style", DEFAULT_SETTINGS["guitar_style"])),
                float(s.get("target_volume", DEFAULT_SETTINGS["target_volume"])),
                str(s.get("preset", DEFAULT_SETTINGS["preset"])),
                str(s.get("max_steps", DEFAULT_SETTINGS["max_steps"])),
                "",  # initial VRAM status text
                str(s.get("bitrate", DEFAULT_SETTINGS["bitrate"])),
                str(s.get("output_sample_rate", DEFAULT_SETTINGS["output_sample_rate"])),
                str(s.get("bit_depth", DEFAULT_SETTINGS["bit_depth"])),
            )
            if not mp3:
                raise HTTPException(status_code=500, detail=msg)
            return {"ok": True, "job_id": job_id, "path": mp3, "status": msg, "vram": vram}
        finally:
            # Always clear the busy state, including on failure.
            set_busy(False, None)
    finally:
        _RENDER_GATE.release()
def _start_fastapi():
    """Serve ``fastapp`` with uvicorn on 0.0.0.0:8555."""
    server_options = {"host": "0.0.0.0", "port": 8555, "log_level": "info"}
    uvicorn.run(fastapp, **server_options)


# Run the API server on a daemon thread so it does not keep the process alive
# on its own; the rest of this module continues on the main thread.
api_thread = threading.Thread(daemon=True, target=_start_fastapi)
api_thread.start()
logger.info("FastAPI server started on http://0.0.0.0:8555")
# ======================================================================================
# GRADIO UI (HIGH CONTRAST / WHITE TEXT)
# ======================================================================================
# Stylesheet passed to gr.Blocks(css=...): dark backgrounds with white text
# forced via !important, blue buttons with a green focus outline, plus the
# `.group-container` and `.header` classes used by the layout.
CSS = """
:root { color-scheme: dark; }
body, .gradio-container, .block, .tabs, .panel, .form, .wrap { background: #0B0B0D !important; color: #FFFFFF !important; }
* { color: #FFFFFF !important; }
label, p, span, h1, h2, h3, h4, h5, h6 { color: #FFFFFF !important; }
input, textarea, select { background: #15151A !important; color: #FFFFFF !important; border: 1px solid #2B2B33 !important; }
button { background: #1F6FEB !important; color: #FFFFFF !important; border: 2px solid transparent !important; border-radius: 8px !important; padding: 10px 16px !important; font-weight: 700 !important; }
button:hover { background: #2D7BFF !important; }
button:focus { outline: 3px solid #00C853 !important; }
.slider > input { accent-color: #FFD600 !important; }
.group-container { border: 1px solid #2B2B33; border-radius: 10px; padding: 16px; }
.header { text-align:center; padding: 12px 16px; border-bottom: 2px solid #00C853; }
.header h1 { font-size: 28px; margin: 6px 0 0 0; }
.header .logo { font-size: 44px; }
"""
# `loaded` aliases CURRENT_SETTINGS; it supplies the initial value of every
# control built below.
loaded = CURRENT_SETTINGS
logger.info("Building Gradio UI...")
# Build the whole UI inside one Blocks context: prompt + genre buttons,
# generation settings, output widgets and a log viewer.
with gr.Blocks(css=CSS, analytics_enabled=False, title="GhostAI Music Generator") as demo:
gr.Markdown(f"""
<div class="header" role="banner" aria-label="GhostAI Music Generator">
<div class="logo">👻</div>
<h1>GhostAI Music Generator</h1>
<p>30s chunking, seamless joins, 1-minute ready, API status & settings save</p>
</div>
""")
# --- Prompt text box and genre shortcut buttons ---
with gr.Column(elem_classes="input-container"):
gr.Markdown("### Prompt")
instrumental_prompt = gr.Textbox(
label="Instrumental Prompt",
placeholder="Type your instrumental prompt or use genre buttons below",
lines=4,
value=loaded.get("instrumental_prompt", ""),
)
with gr.Row():
rhcp_btn = gr.Button("Red Hot Chili Peppers 🌶️")
nirvana_btn = gr.Button("Nirvana 🎸")
pearl_jam_btn = gr.Button("Pearl Jam 🦪")
soundgarden_btn = gr.Button("Soundgarden 🌑")
foo_fighters_btn = gr.Button("Foo Fighters 🤘")
with gr.Row():
smashing_pumpkins_btn = gr.Button("Smashing Pumpkins 🎃")
radiohead_btn = gr.Button("Radiohead 🧠")
# NOTE(review): the variable is named classic_rock_btn but the label reads
# "Metallica Heavy Metal" - confirm which is intended.
classic_rock_btn = gr.Button("Metallica Heavy Metal 🎸")
alternative_rock_btn = gr.Button("Alternative Rock 🎵")
post_punk_btn = gr.Button("Post-Punk 🖤")
with gr.Row():
indie_rock_btn = gr.Button("Indie Rock 🎤")
funk_rock_btn = gr.Button("Funk Rock 🕺")
detroit_techno_btn = gr.Button("Detroit Techno 🎛️")
deep_house_btn = gr.Button("Deep House 🏠")
# --- Generation settings; each control starts from the loaded settings ---
with gr.Column(elem_classes="settings-container"):
gr.Markdown("### Settings")
with gr.Group(elem_classes="group-container"):
cfg_scale = gr.Slider(1.0, 10.0, step=0.1, value=float(loaded.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), label="CFG Scale")
top_k = gr.Slider(10, 500, step=10, value=int(loaded.get("top_k", DEFAULT_SETTINGS["top_k"])), label="Top-K")
top_p = gr.Slider(0.0, 1.0, step=0.01, value=float(loaded.get("top_p", DEFAULT_SETTINGS["top_p"])), label="Top-P")
temperature = gr.Slider(0.1, 2.0, step=0.01, value=float(loaded.get("temperature", DEFAULT_SETTINGS["temperature"])), label="Temperature")
total_duration = gr.Dropdown(choices=[30, 60, 90, 120], value=int(loaded.get("total_duration", 60)), label="Song Length (seconds)")
bpm = gr.Slider(60, 180, step=1, value=int(loaded.get("bpm", 120)), label="Tempo (BPM)")
drum_beat = gr.Dropdown(choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"], value=str(loaded.get("drum_beat", "none")), label="Drum Beat")
synthesizer = gr.Dropdown(choices=["none", "analog synth", "digital pad", "arpeggiated synth"], value=str(loaded.get("synthesizer", "none")), label="Synthesizer")
rhythmic_steps = gr.Dropdown(choices=["none", "syncopated steps", "steady steps", "complex steps"], value=str(loaded.get("rhythmic_steps", "none")), label="Rhythmic Steps")
bass_style = gr.Dropdown(choices=["none", "slap bass", "deep bass", "melodic bass"], value=str(loaded.get("bass_style", "none")), label="Bass Style")
guitar_style = gr.Dropdown(choices=["none", "distorted", "clean", "jangle"], value=str(loaded.get("guitar_style", "none")), label="Guitar Style")
target_volume = gr.Slider(-30.0, -20.0, step=0.5, value=float(loaded.get("target_volume", -23.0)), label="Target Loudness (dBFS RMS)")
preset = gr.Dropdown(choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], value=str(loaded.get("preset", "default")), label="Preset")
max_steps = gr.Dropdown(choices=[1000, 1200, 1300, 1500], value=int(loaded.get("max_steps", 1500)), label="Max Steps (per chunk hint)")
# Export options have no visible widget; they are held in State components
# (as strings) and changed through the quick-set buttons below.
bitrate_state = gr.State(value=str(loaded.get("bitrate", "192k")))
sample_rate_state = gr.State(value=str(loaded.get("output_sample_rate", "48000")))
bit_depth_state = gr.State(value=str(loaded.get("bit_depth", "16")))
with gr.Row():
bitrate_128_btn = gr.Button("Bitrate 128k")
bitrate_192_btn = gr.Button("Bitrate 192k")
bitrate_320_btn = gr.Button("Bitrate 320k")
with gr.Row():
sample_rate_22050_btn = gr.Button("SR 22.05k")
sample_rate_44100_btn = gr.Button("SR 44.1k")
sample_rate_48000_btn = gr.Button("SR 48k")
with gr.Row():
bit_depth_16_btn = gr.Button("16-bit")
bit_depth_24_btn = gr.Button("24-bit")
# Action buttons: generate, clear, and settings save/load/reset.
with gr.Row():
gen_btn = gr.Button("Generate Music 🚀")
clr_btn = gr.Button("Clear 🧹")
save_btn = gr.Button("Save Settings 💾")
load_btn = gr.Button("Load Settings 📂")
reset_btn = gr.Button("Reset Defaults ♻️")
# --- Output: generated audio file, status message and VRAM text ---
with gr.Column(elem_classes="output-container"):
gr.Markdown("### Output")
out_audio = gr.Audio(label="Generated Track", type="filepath")
status_box = gr.Textbox(label="Status", interactive=False)
vram_box = gr.Textbox(label="VRAM Usage", interactive=False, value="")
# --- Log viewer ---
with gr.Column(elem_classes="logs-container"):
gr.Markdown("### Logs")
log_output = gr.Textbox(label="Last Log File", lines=16, interactive=False)
log_btn = gr.Button("View Last Log")
# Genre buttons -> prompt text (chunk_num fixed = 1 for initial suggestion)
# Each genre handler receives the current tempo and instrument selections and
# writes its result into the prompt text box.
# NOTE(review): only the RHCP handler is given a seventh input (a constant
# State of 1); the other genre handlers below receive six inputs. Confirm the
# set_*_prompt signatures make this difference intentional.
rhcp_btn.click(
set_red_hot_chili_peppers_prompt,
inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, gr.State(value=1)],
outputs=instrumental_prompt
)
nirvana_btn.click(set_nirvana_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
pearl_jam_btn.click(set_pearl_jam_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
soundgarden_btn.click(set_soundgarden_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
foo_fighters_btn.click(set_foo_fighters_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
radiohead_btn.click(set_radiohead_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
classic_rock_btn.click(set_classic_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
alternative_rock_btn.click(set_alternative_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
post_punk_btn.click(set_post_punk_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
indie_rock_btn.click(set_indie_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
funk_rock_btn.click(set_funk_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
detroit_techno_btn.click(set_detroit_techno_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
deep_house_btn.click(set_deep_house_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
# Bitrate / SR / Bit depth quick-sets
# These handlers take no inputs; each one's return value replaces the
# corresponding State component.
bitrate_128_btn.click(set_bitrate_128, outputs=bitrate_state)
bitrate_192_btn.click(set_bitrate_192, outputs=bitrate_state)
bitrate_320_btn.click(set_bitrate_320, outputs=bitrate_state)
sample_rate_22050_btn.click(set_sample_rate_22050, outputs=sample_rate_state)
sample_rate_44100_btn.click(set_sample_rate_44100, outputs=sample_rate_state)
sample_rate_48000_btn.click(set_sample_rate_48000, outputs=sample_rate_state)
bit_depth_16_btn.click(set_bit_depth_16, outputs=bit_depth_state)
bit_depth_24_btn.click(set_bit_depth_24, outputs=bit_depth_state)
# Generate
# The current VRAM text box is passed in as an input and is also one of the
# three outputs (audio path, status message, VRAM text).
gen_btn.click(
generate_music_wrapper,
inputs=[
instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
preset, max_steps, vram_box, bitrate_state, sample_rate_state, bit_depth_state
],
outputs=[out_audio, status_box, vram_box]
)
# Clear
# The output order must match the tuple returned by clear_inputs.
clr_btn.click(
clear_inputs, outputs=[
instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
]
)
# Save / Load / Reset actions
def _save_action(
    instrumental_prompt_v, cfg_v, top_k_v, top_p_v, temp_v, dur_v, bpm_v,
    drum_v, synth_v, steps_v, bass_v, guitar_v, vol_v, preset_v, maxsteps_v, br_v, sr_v, bd_v
):
    """Persist the current control values and publish them as live settings.

    Each value is coerced to the type stored in the settings mapping (float,
    int or str; the prompt is stored as received), written with
    ``save_settings_to_file`` and then copied into ``CURRENT_SETTINGS``.

    Returns:
        The confirmation message shown in the status box.
    """
    snapshot = dict(
        instrumental_prompt=instrumental_prompt_v,
        cfg_scale=float(cfg_v),
        top_k=int(top_k_v),
        top_p=float(top_p_v),
        temperature=float(temp_v),
        total_duration=int(dur_v),
        bpm=int(bpm_v),
        drum_beat=str(drum_v),
        synthesizer=str(synth_v),
        rhythmic_steps=str(steps_v),
        bass_style=str(bass_v),
        guitar_style=str(guitar_v),
        target_volume=float(vol_v),
        preset=str(preset_v),
        max_steps=int(maxsteps_v),
        bitrate=str(br_v),
        output_sample_rate=str(sr_v),
        bit_depth=str(bd_v),
    )
    save_settings_to_file(snapshot)
    CURRENT_SETTINGS.update(snapshot)
    return "✅ Settings saved."
def _load_action():
    """Reload settings from file, publish them, and return control values.

    Returns:
        The 18 control values in the order of the Load button's ``outputs``
        list, followed by the status message.
    """
    restored = load_settings_from_file()
    for name, value in restored.items():
        CURRENT_SETTINGS[name] = value
    control_order = (
        "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
        "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
        "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
        "bitrate", "output_sample_rate", "bit_depth",
    )
    return tuple(restored[name] for name in control_order) + ("✅ Settings loaded.",)
def _reset_action():
    """Restore ``DEFAULT_SETTINGS``: save them, publish them, return them.

    Returns:
        The 18 default control values in the order of the Reset button's
        ``outputs`` list, followed by the status message.
    """
    defaults = DEFAULT_SETTINGS.copy()
    save_settings_to_file(defaults)
    CURRENT_SETTINGS.update(defaults)
    control_order = (
        "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
        "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
        "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
        "bitrate", "output_sample_rate", "bit_depth",
    )
    return tuple(defaults[name] for name in control_order) + ("✅ Defaults restored.",)
# Save: pass all 18 control values, in the parameter order of _save_action;
# the returned message goes to the status box.
save_btn.click(
_save_action,
inputs=[
instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
],
outputs=status_box
)
# Load: no inputs; the 18 returned values refresh the controls and the 19th
# (the message) goes to the status box.
load_btn.click(
_load_action,
outputs=[
instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
]
)
# Reset: same output layout as Load, fed by _reset_action.
reset_btn.click(
_reset_action,
outputs=[
instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
]
)
# Logs
# Show whatever get_latest_log returns in the log text box.
log_btn.click(get_latest_log, outputs=log_output)
# ======================================================================================
# LAUNCH GRADIO
# ======================================================================================
logger.info("Launching Gradio UI at http://0.0.0.0:9999 ...")
# Listen on all interfaces, port 9999; no public share link, no browser
# auto-open, and errors surfaced in the UI.
_launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 9999,
    "share": False,
    "inbrowser": False,
    "show_error": True,
}
try:
    demo.launch(**_launch_options)
except Exception as launch_error:
    # A failed launch is fatal: log the error and traceback, then exit non-zero.
    logger.error(f"Failed to launch Gradio UI: {launch_error}")
    logger.error(traceback.format_exc())
    sys.exit(1)