|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import gc |
|
|
import re |
|
|
import json |
|
|
import time |
|
|
import mmap |
|
|
import math |
|
|
import torch |
|
|
import random |
|
|
import logging |
|
|
import warnings |
|
|
import traceback |
|
|
import subprocess |
|
|
import numpy as np |
|
|
import torchaudio |
|
|
import gradio as gr |
|
|
import gradio_client.utils |
|
|
from pydub import AudioSegment |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, Dict, Any, List |
|
|
from torch.cuda.amp import autocast |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Body |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
import uvicorn |
|
|
import threading |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock implementation so the wrapper can delegate to it.
_original_get_type = gradio_client.utils.get_type


def _patched_get_type(schema):
    """Return ``"boolean"`` for bare boolean JSON schemas; delegate everything else.

    JSON Schema allows a schema node to be literally ``True``/``False``. Those
    are answered here before reaching the stock helper. NOTE(review): the patch
    presumably exists because the installed gradio_client mishandles such
    schemas; confirm before removing it.
    """
    return "boolean" if isinstance(schema, bool) else _original_get_type(schema)


# Install the wrapper in place of the library function (process-wide monkeypatch).
gradio_client.utils.get_type = _patched_get_type
|
|
|
|
|
|
|
|
# Silence every Python warning process-wide (library deprecation noise, etc.).
warnings.filterwarnings("ignore")

# Limit the CUDA caching allocator's block split size to reduce fragmentation.
# NOTE(review): PyTorch reads this variable when the CUDA allocator first
# initialises. It is set after `import torch` here, which presumably still
# works because no CUDA allocation has happened yet; confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Prefer deterministic cuDNN kernels over benchmark-autotuned ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Logging: one timestamped file per process start under ./logs, mirrored to stdout.
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("ghostai-musicgen")

# The app only supports CUDA: exit the process (status 1) at import time otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    logger.error("CUDA is required. Exiting.")
    sys.exit(1)
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# NOTE(review): the model is loaded/run under torch autocast(float16); whether the
# weights themselves are fp16 is not established here. The message is informational.
logger.info("Precision: fp16 model, fp32 CPU audio ops")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# JSON file (relative to the working directory) that persists UI/API settings.
SETTINGS_FILE = "settings.json"

# Fallback value for every setting. Also used to fill keys missing from
# SETTINGS_FILE. output_sample_rate and bit_depth are stored as strings
# (generate_music converts them with int()).
DEFAULT_SETTINGS: Dict[str, Any] = {
    # MusicGen sampling parameters (overridden by PRESETS unless preset == "default").
    "cfg_scale": 5.8,
    "top_k": 250,
    "top_p": 0.95,
    "temperature": 0.90,
    # Requested length in seconds (generate_music clamps it to 30-120).
    "total_duration": 60,
    # 120 doubles as the sentinel that makes several set_*_prompt builders pick a random tempo.
    "bpm": 120,
    # Style overrides; "none" means "use the prompt builder's default".
    "drum_beat": "none",
    "synthesizer": "none",
    "rhythmic_steps": "none",
    "bass_style": "none",
    "guitar_style": "none",
    # Target RMS level in dB, passed to rms_normalize as target_rms_db.
    "target_volume": -23.0,
    "preset": "default",
    # Accepted by generate_music but not used by it.
    "max_steps": 1500,
    # MP3 export options.
    "bitrate": "192k",
    "output_sample_rate": "48000",
    "bit_depth": "16",
    "instrumental_prompt": ""
}
|
|
|
|
|
def load_settings_from_file() -> Dict[str, Any]:
    """Load persisted settings, filling any missing keys from DEFAULT_SETTINGS.

    Returns:
        The settings dict read from SETTINGS_FILE, with defaults added for
        absent keys. A fresh copy of DEFAULT_SETTINGS is returned in these cases:
        the file does not exist, it cannot be read or parsed, or its top level
        is not a JSON object. Problems are logged, never raised.
    """
    try:
        if os.path.exists(SETTINGS_FILE):
            with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                # e.g. a hand-edited file containing a list; don't report it as an I/O failure.
                logger.error(f"Ignoring {SETTINGS_FILE}: expected a JSON object, got {type(data).__name__}")
                return DEFAULT_SETTINGS.copy()
            for k, v in DEFAULT_SETTINGS.items():
                data.setdefault(k, v)
            logger.info(f"Loaded settings from {SETTINGS_FILE}")
            return data
    except Exception as e:
        logger.error(f"Failed reading {SETTINGS_FILE}: {e}")
    return DEFAULT_SETTINGS.copy()


def save_settings_to_file(settings: Dict[str, Any]) -> None:
    """Persist *settings* to SETTINGS_FILE as indented JSON, atomically.

    The JSON is written to a sibling ``.tmp`` file and then moved over
    SETTINGS_FILE with os.replace. A crash, or a value json cannot serialise,
    therefore never leaves a truncated settings file behind. Previously,
    writing in place did exactly that, and the next load fell back to defaults.
    Errors are logged, not raised.
    """
    tmp_path = f"{SETTINGS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(settings, f, indent=2)
        os.replace(tmp_path, SETTINGS_FILE)
        logger.info(f"Saved settings to {SETTINGS_FILE}")
    except Exception as e:
        logger.error(f"Failed saving {SETTINGS_FILE}: {e}")
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass


# Settings snapshot loaded once at import time.
CURRENT_SETTINGS = load_settings_from_file()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_memory() -> Optional[float]:
    """Free unreachable objects and release cached GPU memory.

    Returns:
        torch.cuda.memory_allocated() in MiB after cleanup, or None if any
        step failed (the error and traceback are logged).
    """
    try:
        # Collect first. empty_cache() can only hand back blocks whose tensors
        # are already dead, so the previous order (empty_cache before
        # gc.collect) kept cycle-held memory cached until the next call.
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        vram_mb = torch.cuda.memory_allocated() / 1024**2
        logger.info(f"Memory cleaned. VRAM={vram_mb:.2f} MB")
        return vram_mb
    except Exception as e:
        logger.error(f"clean_memory failed: {e}")
        logger.error(traceback.format_exc())
        return None
|
|
|
|
|
def check_vram() -> Optional[int]:
    """Report the first GPU's free memory as seen by ``nvidia-smi``.

    Returns:
        Free VRAM in MiB (total - used) for the first GPU row, or None when
        nvidia-smi is missing, fails, times out, or prints something
        unparsable. When free VRAM is below 5000 MiB, the GPU's compute
        processes are also logged.
    """
    try:
        # timeout: a wedged driver can make nvidia-smi hang, which previously
        # blocked model loading (and thus the whole app) indefinitely.
        r = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'],
            capture_output=True, text=True, timeout=10
        )
        lines = r.stdout.splitlines()
        if r.returncode != 0 or len(lines) < 2:
            logger.warning(f"nvidia-smi returned no GPU memory data (exit code {r.returncode})")
            return None
        # Row 0 is the CSV header; row 1 is the first GPU, e.g. "1234 MiB, 24576 MiB".
        numbers = re.findall(r'\d+', lines[1])
        if len(numbers) != 2:
            logger.warning(f"Unexpected nvidia-smi output: {lines[1]!r}")
            return None
        used_mb, total_mb = map(int, numbers)
        free_mb = total_mb - used_mb
        logger.info(f"VRAM: used {used_mb} MiB | free {free_mb} MiB | total {total_mb} MiB")
        if free_mb < 5000:
            logger.warning(f"Low free VRAM ({free_mb} MiB). Running processes:")
            procs = subprocess.run(
                ['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'],
                capture_output=True, text=True, timeout=10
            )
            logger.info(f"\n{procs.stdout}")
        return free_mb
    except Exception as e:
        logger.error(f"check_vram failed: {e}")
        return None
|
|
|
|
|
def check_disk_space(path=".", min_free_gb: float = 1.0) -> bool:
    """Return True if the filesystem holding *path* has at least *min_free_gb* GiB free.

    Uses shutil.disk_usage, which works on POSIX and Windows. The previous
    os.statvfs call does not exist on Windows, so there this check always
    returned False and generation was refused. A warning is logged when space
    is below the threshold. Any error is logged and reported as False.
    """
    import shutil

    try:
        free_gb = shutil.disk_usage(path).free / (1024**3)
        if free_gb < min_free_gb:
            logger.warning(f"Low disk space: {free_gb:.2f} GB")
        return free_gb >= min_free_gb
    except Exception as e:
        logger.error(f"Disk space check failed: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_stereo(audio_segment: AudioSegment, sample_rate=48000, sample_width=2) -> AudioSegment:
    """Return *audio_segment* as 2-channel audio at *sample_rate* Hz.

    ``sample_width`` is accepted for call-site compatibility but is not used.
    If a pydub conversion fails, the error is logged and the segment is
    returned as converted up to that point.
    """
    seg = audio_segment
    try:
        if seg.channels != 2:
            seg = seg.set_channels(2)
        if seg.frame_rate != sample_rate:
            seg = seg.set_frame_rate(sample_rate)
    except Exception as e:
        logger.error(f"ensure_stereo failed: {e}")
    return seg
|
|
|
|
|
def calculate_rms(segment: "AudioSegment") -> float:
    """Return the RMS of all raw sample values of *segment* (all channels pooled).

    The value is in raw integer sample units (e.g. at most 32767 for 16-bit
    audio), not dB. Returns 0.0 for an empty segment or on error (logged).
    Previously an empty segment produced NaN.
    """
    try:
        # float64: accumulating squares of 32-bit samples in float32 loses precision.
        samples = np.asarray(segment.get_array_of_samples(), dtype=np.float64)
        if samples.size == 0:
            # np.mean([]) is NaN (with a RuntimeWarning); report silence instead.
            return 0.0
        return float(np.sqrt(np.mean(np.square(samples))))
    except Exception as e:
        logger.error(f"calculate_rms failed: {e}")
        return 0.0
|
|
|
|
|
def hard_limit(audio_segment: AudioSegment, limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Clip every sample of *audio_segment* to ``limit_db`` relative to full scale.

    This is a hard clipper: no look-ahead and no gain smoothing. The segment
    is first made stereo at *sample_rate* via ensure_stereo.

    Full scale and the output dtype are derived from the segment's actual
    sample width (1, 2 or 4 bytes). pydub converts 24-bit data to 4-byte
    samples when a segment is built, so the old ``sample_width == 3`` branch
    never matched. 24-bit material was clipped at the 16-bit range and re-packed
    as int16 under a 4-byte header, which corrupted it.

    Returns:
        The limited stereo segment, or the (stereo-converted) input on error (logged).
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(audio_segment.sample_width)
        if dtype is None:
            raise ValueError(f"unsupported sample width {audio_segment.sample_width}")
        limit = 10 ** (limit_db / 20.0) * float(np.iinfo(dtype).max)
        samples = np.asarray(audio_segment.get_array_of_samples(), dtype=np.float64)
        samples = np.clip(samples, -limit, limit).astype(dtype)
        # Keep whole stereo frames if the conversion above left an odd count.
        if len(samples) % 2 != 0:
            samples = samples[:-1]
        return AudioSegment(
            samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2
        )
    except Exception as e:
        logger.error(f"hard_limit failed: {e}")
        return audio_segment
|
|
|
|
|
def rms_normalize(segment: AudioSegment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Gain *segment* so its RMS sits at ``target_rms_db`` dBFS, then hard-limit peaks.

    Full scale is derived from the segment's real sample width:
    2**(8*width - 1) - 1, i.e. 32767 for 16-bit and 2**31 - 1 for the 4-byte
    samples pydub uses for 24-bit audio. The old ``sample_width == 3`` test
    never matched, so 24-bit audio was normalised against the 16-bit range and
    ended up far too quiet. Silent segments (RMS 0) get no gain. Peaks are then
    clipped at ``peak_limit_db`` by hard_limit.

    Returns:
        The processed segment, or the input (possibly stereo-converted) on error (logged).
    """
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        full_scale = float(2 ** (8 * segment.sample_width - 1) - 1)
        target_rms = 10 ** (target_rms_db / 20) * full_scale
        current_rms = calculate_rms(segment)
        if current_rms > 0:
            gain_factor = target_rms / current_rms
            # Floor the factor so log10 never sees 0 (caps the cut at -120 dB).
            segment = segment.apply_gain(20 * np.log10(max(gain_factor, 1e-6)))
        segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate)
        return segment
    except Exception as e:
        logger.error(f"rms_normalize failed: {e}")
        return segment
|
|
|
|
|
def balance_stereo(audio_segment: AudioSegment, noise_threshold=-40, sample_rate=48000) -> AudioSegment:
    """Match the left and right channel RMS levels of *audio_segment*.

    Samples whose per-sample level (20*log10|x|) is at or below
    ``noise_threshold`` are zeroed. Each channel's RMS is then measured over
    its non-zero samples, and both channels are scaled to the mean of the two
    RMS values.

    NOTE(review): the level is computed on raw integer sample values, not
    relative to full scale. Every non-zero integer sample is therefore >= 0 dB,
    and a negative threshold only "zeroes" samples that are already zero. This
    is kept as-is, because a true per-sample dBFS gate would chop every
    waveform near its zero crossings.

    Changes vs. the previous version:
    * The output dtype follows the segment's actual sample width. 4-byte
      (24-bit source) audio used to be re-packed as int16 and corrupted.
    * Scaled samples are clipped to the integer range before the cast.
      Boosting the quieter channel could otherwise wrap around and produce
      loud clicks.

    Returns:
        The balanced stereo segment, or the (stereo-converted) input on error (logged).
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float64)
        if audio_segment.channels != 2:
            return audio_segment
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(audio_segment.sample_width)
        if dtype is None:
            raise ValueError(f"unsupported sample width {audio_segment.sample_width}")
        # Interleaved L/R samples -> (frames, 2).
        stereo = samples.reshape(-1, 2)
        db = 20 * np.log10(np.abs(stereo) + 1e-10)
        stereo = stereo * (db > noise_threshold)
        left = stereo[:, 0]
        right = stereo[:, 1]
        l_rms = np.sqrt(np.mean(left[left != 0] ** 2)) if np.any(left != 0) else 0
        r_rms = np.sqrt(np.mean(right[right != 0] ** 2)) if np.any(right != 0) else 0
        if l_rms > 0 and r_rms > 0:
            avg = (l_rms + r_rms) / 2
            stereo[:, 0] *= (avg / l_rms)
            stereo[:, 1] *= (avg / r_rms)
        info = np.iinfo(dtype)
        out = np.clip(stereo.flatten(), info.min, info.max).astype(dtype)
        return AudioSegment(
            out.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2
        )
    except Exception as e:
        logger.error(f"balance_stereo failed: {e}")
        return audio_segment
|
|
|
|
|
def apply_noise_gate(audio_segment: AudioSegment, threshold_db=-80, sample_rate=48000) -> AudioSegment:
    """Zero every sample whose per-sample level is at or below ``threshold_db``.

    The segment is first made stereo at *sample_rate*. Level is computed as
    20*log10(|x|) on the raw integer sample values.

    NOTE(review): because that level is not relative to full scale, every
    non-zero integer sample is >= 0 dB. With the negative thresholds used by
    the visible callers, the gate is effectively a pass-through. This is kept
    deliberately: a true per-sample dBFS gate would chop waveforms at every
    zero crossing.

    Changes vs. the previous version:
    * The output dtype follows the segment's actual sample width. 4-byte
      (24-bit source) audio used to be re-packed as int16 and corrupted.
    * The 0/1 mask is applied once. The old second pass could not change
      anything.

    Returns:
        The gated stereo segment, or the (stereo-converted) input on error (logged).
    """
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float64)
        if audio_segment.channels != 2:
            return audio_segment
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(audio_segment.sample_width)
        if dtype is None:
            raise ValueError(f"unsupported sample width {audio_segment.sample_width}")
        stereo = samples.reshape(-1, 2)
        level_db = 20 * np.log10(np.abs(stereo) + 1e-10)
        stereo = stereo * (level_db > threshold_db)
        out = stereo.flatten().astype(dtype)
        return AudioSegment(
            out.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2
        )
    except Exception as e:
        logger.error(f"apply_noise_gate failed: {e}")
        return audio_segment
|
|
|
|
|
def apply_eq(segment: AudioSegment, sample_rate=48000) -> AudioSegment:
    """Band-limit *segment* to 20 Hz - 8 kHz and lower its level by 16 dB in total.

    Despite the name there is no per-band EQ. After a 20 Hz high-pass and an
    8 kHz low-pass, three broadband cuts (-3, -3, -10 dB) are applied one
    after another. On failure the error is logged and the segment is returned
    as processed up to that point.
    """
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        segment = segment.high_pass_filter(20)
        segment = segment.low_pass_filter(8000)
        for cut_db in (3, 3, 10):
            segment = segment - cut_db
        return segment
    except Exception as e:
        logger.error(f"apply_eq failed: {e}")
        return segment
|
|
|
|
|
def apply_fade(segment: AudioSegment, fade_in_duration=500, fade_out_duration=500) -> AudioSegment:
    """Return *segment* (made stereo at its own frame rate) with fades at both ends.

    Durations are in milliseconds. On failure the error is logged and the
    stereo-converted segment is returned without fades.
    """
    try:
        segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width)
        return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
    except Exception as e:
        logger.error(f"apply_fade failed: {e}")
        return segment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num):
    """Build a Red Hot Chili Peppers-style funk-rock prompt for one chunk.

    A bpm of exactly 120 is replaced with a random tempo in [90, 130].
    Style arguments equal to "none" fall back to built-in defaults.
    ``rhythmic_steps`` is accepted for signature parity but unused.
    Chunk 1 ends with an intro/verse cue; every other chunk with a
    chorus/outro cue. Returns "" if building the string fails.
    """
    try:
        if bpm == 120:
            bpm = random.randint(90, 130)
        details = "".join((
            f", {guitar_style} guitar" if guitar_style != "none" else ", energetic guitar riffs",
            f", {bass_style} bass" if bass_style != "none" else ", funky slap bass",
            f", {drum_beat} drums" if drum_beat != "none" else ", standard rock drums with funk fills",
            f", {synthesizer}" if synthesizer != "none" else "",
        ))
        if chunk_num == 1:
            ending = ", dynamic intro and expressive verse."
        else:
            ending = ", powerful chorus and energetic outro."
        return f"Instrumental alternative rock by Red Hot Chili Peppers{details}, funk-rock energy at {bpm} BPM{ending}"
    except Exception:
        return ""
|
|
|
|
|
def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Nirvana-style grunge prompt.

    A bpm of exactly 120 becomes a random tempo in [100, 130]. A "none" bass,
    guitar or rhythm choice is replaced by a random pick from a small list
    (drawn in that order). Returns "" if building the string fails.
    """
    try:
        if bpm == 120:
            bpm = random.randint(100, 130)
        drum = ", standard rock drums, punk energy" if drum_beat == "none" else f", {drum_beat} drums, punk energy"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass = random.choice(['deep bass', 'melodic bass']) if bass_style == "none" else bass_style
        guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
        rhythm = random.choice(['steady steps', 'dynamic shifts']) if rhythmic_steps == "none" else rhythmic_steps
        return f"Instrumental grunge by Nirvana, {guitar}, {bass}{drum}{synth}, raw lo-fi production, {rhythm} at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Pearl Jam-style grunge prompt.

    A bpm of exactly 120 becomes a random tempo in [100, 140]. A "none"
    guitar or rhythm choice is replaced by a random pick (guitar first).
    Other "none" styles use fixed defaults. Returns "" if building the
    string fails.
    """
    try:
        if bpm == 120:
            bpm = random.randint(100, 140)
        if drum_beat != "none":
            drum = f", {drum_beat} drums, driving rhythm"
        else:
            drum = ", standard rock drums, driving rhythm"
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        bass_name = bass_style if bass_style != "none" else "melodic bass"
        guitar_name = guitar_style if guitar_style != "none" else random.choice(['clean guitar', 'distorted guitar'])
        rhythm_name = rhythmic_steps if rhythmic_steps != "none" else random.choice(['steady steps', 'syncopated steps'])
        return (
            f"Instrumental grunge by Pearl Jam, {guitar_name}, soulful leads, {bass_name}, emotional tone"
            f"{drum}{synth}, classic rock influences, {rhythm_name} at {bpm} BPM."
        )
    except Exception:
        return ""
|
|
|
|
|
def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Soundgarden-style heavy grunge prompt.

    A bpm of exactly 120 becomes a random tempo in [90, 140]. "none" styles
    fall back to fixed defaults (no randomness beyond the tempo). Returns ""
    if building the string fails.
    """
    try:
        def pick(value, default):
            return default if value == "none" else value

        if bpm == 120:
            bpm = random.randint(90, 140)
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        drum = "standard rock drums" if drum_beat == "none" else f"{drum_beat} drums"
        return (
            "Instrumental grunge with heavy metal influences by Soundgarden"
            f", {pick(guitar_style, 'distorted guitar')}, downtuned riffs, psychedelic vibe"
            f", {pick(bass_style, 'deep bass')}, sludgy tone"
            f", {drum}, heavy rhythm"
            f"{synth}"
            f", {pick(rhythmic_steps, 'complex steps')} at {bpm} BPM."
        )
    except Exception:
        return ""
|
|
|
|
|
def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Foo Fighters-style post-grunge prompt.

    A bpm of exactly 120 becomes a random tempo in [110, 150]. A "none"
    guitar or rhythm choice is replaced by a random pick (guitar first).
    Drums and bass use fixed defaults. Returns "" if building the string
    fails.
    """
    try:
        if bpm == 120:
            bpm = random.randint(110, 150)
        drum_name = "standard rock" if drum_beat == "none" else drum_beat
        bass_name = "melodic bass" if bass_style == "none" else bass_style
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        guitar_name = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
        rhythm_name = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps
        head = f"Instrumental alternative rock with post-grunge influences by Foo Fighters, {guitar_name}, anthemic quality"
        middle = f", stadium-ready hooks, {bass_name}, supportive tone, {drum_name} drums, powerful drive{synth}"
        return f"{head}{middle}, {rhythm_name} at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the "classic rock" prompt, which describes Metallica-style thrash metal.

    A bpm of exactly 120 becomes a random tempo in [120, 180]. "none" styles
    fall back to fixed defaults. Returns "" if building the string fails.
    """
    try:
        tempo = random.randint(120, 180) if bpm == 120 else bpm
        fragments = [
            f", {guitar_style if guitar_style != 'none' else 'distorted guitar'}, blazing fast riffs",
            ", aggressive bass" if bass_style == "none" else f", {bass_style}",
            ", double bass drums" if drum_beat == "none" else f", {drum_beat} drums",
            "" if synthesizer == "none" else f", {synthesizer}",
            ", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}",
        ]
        return "Instrumental thrash metal by Metallica" + "".join(fragments) + f" at {tempo} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Smashing Pumpkins-style alternative-rock prompt.

    "none" defaults: dreamy guitar, lush synths, and no drum or bass mention.
    ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns "" if
    building the string fails.
    """
    try:
        def part(value, template, fallback):
            return fallback if value == "none" else template.format(value)

        pieces = (
            part(guitar_style, ", {} guitar", ", dreamy guitar"),
            part(synthesizer, ", {}", ", lush synths"),
            part(drum_beat, ", {} drums", ""),
            part(bass_style, ", {} bass", ""),
        )
        return "Instrumental alternative rock by Smashing Pumpkins" + "".join(pieces) + f" at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Radiohead-style experimental-rock prompt.

    "none" defaults: atmospheric synths, hypnotic bass, and no drum or guitar
    mention. ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns
    "" if building the string fails.
    """
    try:
        spec = (
            (synthesizer, "", ", atmospheric synths"),
            (bass_style, " bass", ", hypnotic bass"),
            (drum_beat, " drums", ""),
            (guitar_style, " guitar", ""),
        )
        body = "".join(default if value == "none" else f", {value}{suffix}" for value, suffix, default in spec)
        return f"Instrumental experimental rock by Radiohead{body} at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Pixies-style alternative-rock prompt.

    "none" defaults: distorted guitar, melodic bass, and no drum or synth
    mention. ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns
    "" if building the string fails.
    """
    try:
        spec = (
            (guitar_style, " guitar", ", distorted guitar"),
            (bass_style, " bass", ", melodic bass"),
            (drum_beat, " drums", ""),
            (synthesizer, "", ""),
        )
        body = "".join(default if value == "none" else f", {value}{suffix}" for value, suffix, default in spec)
        return f"Instrumental alternative rock by Pixies{body} at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Joy Division-style post-punk prompt.

    "none" defaults: jangly guitar, driving bass, precise drums, and no synth.
    ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns "" if
    building the string fails.
    """
    try:
        def part(value, template, fallback):
            return fallback if value == "none" else template.format(value)

        pieces = (
            part(guitar_style, ", {} guitar", ", jangly guitar"),
            part(bass_style, ", {} bass", ", driving bass"),
            part(drum_beat, ", {} drums", ", precise drums"),
            part(synthesizer, ", {}", ""),
        )
        return "Instrumental post-punk by Joy Division" + "".join(pieces) + f" at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build an Arctic Monkeys-style indie-rock prompt.

    "none" defaults: jangly guitar, groovy bass, and no drum or synth mention.
    ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns "" if
    building the string fails.
    """
    try:
        guitar = ", jangly guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", groovy bass" if bass_style == "none" else f", {bass_style} bass"
        drum = "" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return "".join(("Instrumental indie rock by Arctic Monkeys", guitar, bass, drum, synth, f" at {bpm} BPM."))
    except Exception:
        return ""
|
|
|
|
|
def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Rage Against the Machine-style funk-rock prompt.

    "none" defaults: funky guitar, slap bass, heavy drums, and no synth.
    ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns "" if
    building the string fails.
    """
    try:
        guitar = ", funky guitar" if guitar_style == "none" else f", {guitar_style} guitar"
        bass = ", slap bass" if bass_style == "none" else f", {bass_style} bass"
        drum = ", heavy drums" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        return "".join(("Instrumental funk rock by Rage Against the Machine", guitar, bass, drum, synth, f" at {bpm} BPM."))
    except Exception:
        return ""
|
|
|
|
|
def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Juan Atkins-style Detroit techno prompt.

    "none" defaults: pulsing synths, driving bass, four-on-the-floor drums,
    and no guitar. ``bpm`` is used as given; ``rhythmic_steps`` is unused.
    Returns "" if building the string fails.
    """
    try:
        spec = (
            (synthesizer, "", ", pulsing synths"),
            (bass_style, " bass", ", driving bass"),
            (drum_beat, " drums", ", four-on-the-floor drums"),
            (guitar_style, " guitar", ""),
        )
        body = "".join(default if value == "none" else f", {value}{suffix}" for value, suffix, default in spec)
        return f"Instrumental Detroit techno by Juan Atkins{body} at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build a Larry Heard-style deep-house prompt.

    "none" defaults: warm synths, deep bass, steady kick drums, and no guitar.
    ``bpm`` is used as given; ``rhythmic_steps`` is unused. Returns "" if
    building the string fails.
    """
    try:
        def part(value, template, fallback):
            return fallback if value == "none" else template.format(value)

        pieces = (
            part(synthesizer, ", {}", ", warm synths"),
            part(bass_style, ", {} bass", ", deep bass"),
            part(drum_beat, ", {} drums", ", steady kick drums"),
            part(guitar_style, ", {} guitar", ""),
        )
        return "Instrumental deep house by Larry Heard" + "".join(pieces) + f" at {bpm} BPM."
    except Exception:
        return ""
|
|
|
|
|
# Named MusicGen sampling presets. When preset != "default", generate_music
# overwrites the caller's cfg_scale/top_k/top_p/temperature with these values.
# Unknown preset names fall back to the "default" entry. "rock" is currently
# identical to "default".
PRESETS = {
    "default": {"cfg_scale": 5.8, "top_k": 250, "top_p": 0.95, "temperature": 0.90},
    "rock": {"cfg_scale": 5.8, "top_k": 250, "top_p": 0.95, "temperature": 0.90},
    "techno": {"cfg_scale": 5.2, "top_k": 300, "top_p": 0.96, "temperature": 0.95},
    "grunge": {"cfg_scale": 6.2, "top_k": 220, "top_p": 0.94, "temperature": 0.90},
    "indie": {"cfg_scale": 5.5, "top_k": 240, "top_p": 0.95, "temperature": 0.92},
    "funk_rock": {"cfg_scale": 5.8, "top_k": 260, "top_p": 0.96, "temperature": 0.94},
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# audiocraft provides MusicGen. On any import-time failure, log an install hint
# and re-raise the original exception unchanged (the app cannot run without it).
try:
    from audiocraft.models import MusicGen
except Exception as e:
    logger.error("audiocraft is required. pip install audiocraft")
    raise
|
|
|
|
|
def load_model():
    """Load the local MusicGen-large checkpoint onto DEVICE and return it.

    Steps:
    * Log a warning if check_vram() reports < 5000 MiB free.
    * Run clean_memory().
    * Exit the process (status 1) if ./models/musicgen-large does not exist.
    * Load the model.
    * Set default generation params: 30 s per call, two_step_cfg disabled.
    """
    vram_free = check_vram()
    low_on_vram = vram_free is not None and vram_free < 5000
    if low_on_vram:
        logger.warning("Low free VRAM; consider closing other apps.")
    clean_memory()

    model_dir = "./models/musicgen-large"
    if not os.path.exists(model_dir):
        logger.error(f"Model path missing: {model_dir}")
        sys.exit(1)

    logger.info("Loading MusicGen (large)...")
    # NOTE(review): autocast only changes the dtype of ops run inside the
    # context; it is not shown here to convert the loaded weights to fp16.
    with autocast(dtype=torch.float16):
        loaded = MusicGen.get_pretrained(model_dir, device=DEVICE)
    loaded.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen loaded.")
    return loaded


# Loaded once at import time; generate_music reconfigures it per request.
musicgen_model = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_latest_log() -> str:
    """Return the contents of the most recently modified ``musicgen_log_*.log`` in LOG_DIR.

    If no log exists, or reading fails, a short human-readable message is
    returned instead.
    """
    try:
        candidates = list(Path(LOG_DIR).glob("musicgen_log_*.log"))
        if not candidates:
            return "No log files found."
        newest = max(candidates, key=os.path.getmtime)
        return newest.read_text()
    except Exception as e:
        return f"Error reading log: {e}"
|
|
|
|
|
def set_bitrate_128():
    """UI helper: MP3 bitrate value for 128 kbps."""
    return "128k"


def set_bitrate_192():
    """UI helper: MP3 bitrate value for 192 kbps."""
    return "192k"


def set_bitrate_320():
    """UI helper: MP3 bitrate value for 320 kbps."""
    return "320k"


def set_sample_rate_22050():
    """UI helper: output sample-rate value for 22.05 kHz."""
    return "22050"


def set_sample_rate_44100():
    """UI helper: output sample-rate value for 44.1 kHz."""
    return "44100"


def set_sample_rate_48000():
    """UI helper: output sample-rate value for 48 kHz."""
    return "48000"


def set_bit_depth_16():
    """UI helper: bit-depth value for 16-bit output."""
    return "16"


def set_bit_depth_24():
    """UI helper: bit-depth value for 24-bit output."""
    return "24"
|
|
|
|
|
def generate_music_wrapper(*args):
    """Forward *args to generate_music and always run clean_memory() afterwards.

    The cleanup also runs when generate_music raises. The result or exception
    propagates unchanged.
    """
    try:
        result = generate_music(*args)
    finally:
        clean_memory()
    return result
|
|
|
|
|
def _export_torch_to_segment(audio_tensor: torch.Tensor, sample_rate: int, bit_depth_int: int) -> Optional[AudioSegment]:
    """Convert a float audio tensor into a pydub AudioSegment via a temporary WAV.

    The tensor is passed to torchaudio.save with its default channels-first
    layout; presumably (channels, samples), matching the stereo tensors
    built by generate_music. The WAV is written at ``bit_depth_int`` bits,
    read back with pydub, and always deleted.

    Changes vs. the previous version:
    * The temp path comes from tempfile.mkstemp and is unique. Millisecond
      timestamps could collide between concurrent calls.
    * The unused mmap is gone. It leaked when reading failed and raised on
      an empty file.

    Returns:
        The AudioSegment, or None on failure (error and traceback logged).
    """
    import tempfile

    temp = None
    try:
        fd, temp = tempfile.mkstemp(prefix="temp_audio_", suffix=".wav")
        os.close(fd)
        torchaudio.save(temp, audio_tensor, sample_rate, bits_per_sample=bit_depth_int)
        # pydub reads the whole file into memory, so it can be deleted right after.
        return AudioSegment.from_wav(temp)
    except Exception as e:
        logger.error(f"_export_torch_to_segment failed: {e}")
        logger.error(traceback.format_exc())
        return None
    finally:
        if temp is not None:
            try:
                os.remove(temp)
            except OSError:
                pass
|
|
|
|
|
def _crossfade_segments(seg_a: AudioSegment, seg_b: AudioSegment, overlap_ms: int, sample_rate: int, bit_depth_int: int) -> AudioSegment:
    """Join seg_a and seg_b with a raised-cosine (half-Hann) crossfade.

    The last ``overlap_ms`` of seg_a fades out while the first ``overlap_ms``
    of seg_b fades in over the same span. The fade gains sum to 1 at every
    sample, so the result is ``len(seg_a) + len(seg_b) - overlap_ms`` long
    with no level dip. Plain concatenation is used when the overlap is
    non-positive or longer than either segment, and whenever anything fails
    (logged).

    Changes vs. the previous version:
    * Both fades used to be the same full symmetric Hann window
      (0 -> 1 -> 0; flipping it changes nothing), so the joined audio dipped
      to silence at both edges of the overlap. fade_in now rises 0 -> 1 and
      fade_out = 1 - fade_in.
    * The blend used to be hand-quantised (x * 2**23 stored as int32) before
      saving. torchaudio interprets int32 tensors as 32-bit PCM, so at
      24 bits the overlap came out ~48 dB too quiet. The float blend is now
      saved directly and torchaudio quantises it.
    * All temp WAVs get unique tempfile names and are always deleted. The
      crossfade file used to leak on errors.
    """
    import tempfile

    try:
        seg_a = ensure_stereo(seg_a, sample_rate, seg_a.sample_width)
        seg_b = ensure_stereo(seg_b, sample_rate, seg_b.sample_width)
        if overlap_ms <= 0 or len(seg_a) < overlap_ms or len(seg_b) < overlap_ms:
            return seg_a + seg_b

        temp_paths = []

        def _temp_wav(prefix):
            # Unique, pre-created path; closed so torchaudio/pydub can reopen it on any OS.
            fd, path = tempfile.mkstemp(prefix=prefix, suffix=".wav")
            os.close(fd)
            temp_paths.append(path)
            return path

        try:
            prev_wav = _temp_wav("tmp_prev_")
            curr_wav = _temp_wav("tmp_curr_")
            seg_a[-overlap_ms:].export(prev_wav, format="wav")
            seg_b[:overlap_ms].export(curr_wav, format="wav")
            a_audio, sr_a = torchaudio.load(prev_wav)
            b_audio, sr_b = torchaudio.load(curr_wav)
            if sr_a != sample_rate:
                a_audio = torchaudio.functional.resample(a_audio, sr_a, sample_rate, lowpass_filter_width=64)
            if sr_b != sample_rate:
                b_audio = torchaudio.functional.resample(b_audio, sr_b, sample_rate, lowpass_filter_width=64)

            n = min(a_audio.shape[1], b_audio.shape[1])
            if n <= 0:
                return seg_a + seg_b
            a = a_audio[:, :n].to(torch.float32)
            b = b_audio[:, :n].to(torch.float32)

            # Half-Hann ramp 0 -> 1 over the overlap; fade_in + fade_out == 1 everywhere.
            t = torch.linspace(0.0, 1.0, n, dtype=torch.float32)
            fade_in = 0.5 - 0.5 * torch.cos(math.pi * t)
            fade_out = 1.0 - fade_in
            blended = torch.clamp(a * fade_out + b * fade_in, -1.0, 1.0)

            cross_wav = _temp_wav("tmp_cross_")
            torchaudio.save(cross_wav, blended, sample_rate, encoding="PCM_S", bits_per_sample=bit_depth_int)
            blended_seg = AudioSegment.from_wav(cross_wav)
            blended_seg = ensure_stereo(blended_seg, sample_rate, blended_seg.sample_width)

            return seg_a[:-overlap_ms] + blended_seg + seg_b[overlap_ms:]
        finally:
            for p in temp_paths:
                try:
                    os.remove(p)
                except OSError:
                    pass
    except Exception as e:
        logger.error(f"_crossfade_segments failed: {e}")
        return seg_a + seg_b
|
|
|
|
|
def generate_music(
    instrumental_prompt: str,
    cfg_scale: float,
    top_k: int,
    top_p: float,
    temperature: float,
    total_duration: int,
    bpm: int,
    drum_beat: str,
    synthesizer: str,
    rhythmic_steps: str,
    bass_style: str,
    guitar_style: str,
    target_volume: float,
    preset: str,
    max_steps: str,
    vram_status_text: str,
    bitrate: str,
    output_sample_rate: str,
    bit_depth: str
) -> Tuple[Optional[str], str, str]:
    """Generate an instrumental track in 30 s chunks and export it as an MP3.

    Pipeline:
    1. Validate options and apply a sampling preset if one is chosen.
    2. Generate ceil(total_duration / 30) chunks with MusicGen. Every chunk
       after the first is a continuation of the last OVERLAP_SEC of the
       previous one.
    3. Clean up and normalise each chunk at 48 kHz.
    4. Crossfade the chunks together.
    5. Run the final mastering chain.
    6. Resample to output_sample_rate and export the MP3.

    Args:
        instrumental_prompt: Text description; must be non-blank. If it contains
            "Red Hot Chili Peppers", it is replaced for each chunk by
            set_red_hot_chili_peppers_prompt(...).
        cfg_scale, top_k, top_p, temperature: MusicGen sampling parameters.
            They are replaced by PRESETS[preset] unless preset == "default".
        total_duration: Requested length in seconds, clamped to [30, 120].
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style:
            Used only by the Red Hot Chili Peppers prompt builder.
        target_volume: Target RMS in dB passed to rms_normalize.
        max_steps: Accepted for UI/API compatibility; not used here.
        vram_status_text: Echoed back when returning early.
        bitrate: MP3 bitrate such as "192k". On export failure a 128k
            fallback is attempted.
        output_sample_rate: One of "22050", "44100", "48000".
        bit_depth: "16" or "24" (depth of the intermediate WAV files).

    Returns:
        (mp3_path or None, status message, VRAM status text).

    Changes vs. the previous version:
    * Option parsing no longer uses bare ``except:``, which also swallowed
      KeyboardInterrupt/SystemExit.
    * Values outside the advertised choices are rejected instead of flowing
      into the audio pipeline.
    * The model's own sample_rate is used instead of a hard-coded 32000.
    * The random seed is logged so a run can be reproduced.
    * The continuation temp WAV uses a unique tempfile name.
    """
    global musicgen_model
    import tempfile

    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!", vram_status_text

    try:
        # A named preset overrides the four sampling parameters.
        if preset != "default":
            p = PRESETS.get(preset, PRESETS["default"])
            cfg_scale, top_k, top_p, temperature = p["cfg_scale"], p["top_k"], p["top_p"], p["temperature"]
            logger.info(f"Preset '{preset}' applied: cfg={cfg_scale} top_k={top_k} top_p={top_p} temp={temperature}")

        # Validate output options against the values the error messages advertise.
        try:
            output_sr_int = int(output_sample_rate)
        except (TypeError, ValueError):
            output_sr_int = None
        if output_sr_int not in (22050, 44100, 48000):
            return None, "❌ Invalid output sampling rate; choose 22050/44100/48000", vram_status_text
        try:
            bit_depth_int = int(bit_depth)
        except (TypeError, ValueError):
            bit_depth_int = None
        if bit_depth_int not in (16, 24):
            return None, "❌ Invalid bit depth; choose 16 or 24", vram_status_text
        sample_width = 3 if bit_depth_int == 24 else 2

        if not check_disk_space():
            return None, "⚠️ Low disk space (<1GB).", vram_status_text

        # MusicGen renders CHUNK_SEC per call; the requested length is clamped to 30-120 s.
        CHUNK_SEC = 30
        total_duration = max(30, min(int(total_duration), 120))
        num_chunks = math.ceil(total_duration / CHUNK_SEC)

        # All post-processing runs at PROCESS_SR. OVERLAP_SEC is both the
        # continuation-prompt length and the crossfade length between chunks.
        PROCESS_SR = 48000
        OVERLAP_SEC = 0.20
        # Native output rate of the loaded model (32 kHz for MusicGen checkpoints).
        model_sr = int(getattr(musicgen_model, "sample_rate", 32000))

        # Fresh random seed for every RNG in play; logged so a run can be reproduced.
        seed = random.randint(0, 2**31 - 1)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)
        logger.info(f"Generation seed: {seed}")

        musicgen_model.set_generation_params(
            duration=CHUNK_SEC,
            use_sampling=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            cfg_coef=float(cfg_scale),
            two_step_cfg=False,
        )

        vram_status_text = f"Start VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"

        segments: List[AudioSegment] = []
        start_time = time.time()

        for idx in range(num_chunks):
            chunk_idx = idx + 1
            # Every chunk is CHUNK_SEC except possibly the last, which holds the remainder.
            dur = CHUNK_SEC if (idx < num_chunks - 1) else (total_duration - CHUNK_SEC * (num_chunks - 1) or CHUNK_SEC)
            logger.info(f"Generating chunk {chunk_idx}/{num_chunks} ({dur}s)")

            if "Red Hot Chili Peppers" in instrumental_prompt:
                prompt_text = set_red_hot_chili_peppers_prompt(
                    bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_idx
                )
            else:
                prompt_text = instrumental_prompt

            try:
                with torch.no_grad():
                    with autocast(dtype=torch.float16):
                        clean_memory()
                        if idx == 0:
                            audio = musicgen_model.generate([prompt_text], progress=True)[0].cpu()
                        else:
                            # Condition this chunk on the last OVERLAP_SEC of the previous one.
                            prev_seg = segments[-1]
                            prev_seg = apply_noise_gate(prev_seg, threshold_db=-80, sample_rate=PROCESS_SR)
                            prev_seg = balance_stereo(prev_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                            fd, temp_prev = tempfile.mkstemp(prefix="prev_", suffix=".wav")
                            os.close(fd)
                            try:
                                prev_seg.export(temp_prev, format="wav")
                                prev_audio, prev_sr = torchaudio.load(temp_prev)
                                if prev_sr != PROCESS_SR:
                                    prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, PROCESS_SR, lowpass_filter_width=64)
                                if prev_audio.shape[0] != 2:
                                    prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]]
                                prev_audio = prev_audio.to(DEVICE)
                                tail = prev_audio[:, -int(PROCESS_SR * OVERLAP_SEC):]
                                audio = musicgen_model.generate_continuation(
                                    prompt=tail,
                                    prompt_sample_rate=PROCESS_SR,
                                    descriptions=[prompt_text],
                                    progress=True
                                )[0].cpu()
                                del prev_audio, tail
                            finally:
                                try:
                                    os.remove(temp_prev)
                                except OSError:
                                    pass
                        clean_memory()
            except Exception as e:
                logger.error(f"Chunk {chunk_idx} generation failed: {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to generate chunk {chunk_idx}: {e}", vram_status_text

            try:
                # Duplicate mono output to two channels, then move to PROCESS_SR.
                if audio.shape[0] != 2:
                    audio = audio.repeat(2, 1)[:, :audio.shape[1]]
                audio = audio.to(dtype=torch.float32)
                audio = torchaudio.functional.resample(audio, model_sr, PROCESS_SR, lowpass_filter_width=64)
                seg = _export_torch_to_segment(audio, PROCESS_SR, bit_depth_int)
                if seg is None:
                    return None, f"❌ Failed to convert audio for chunk {chunk_idx}", vram_status_text
                seg = ensure_stereo(seg, PROCESS_SR, sample_width)
                seg = seg - 15
                seg = apply_noise_gate(seg, threshold_db=-80, sample_rate=PROCESS_SR)
                seg = balance_stereo(seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                seg = rms_normalize(seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
                seg = apply_eq(seg, sample_rate=PROCESS_SR)
                seg = seg[:dur * 1000]
                segments.append(seg)
                del audio
                clean_memory()
                vram_status_text = f"VRAM after chunk {chunk_idx}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            except Exception as e:
                logger.error(f"Post-processing failed (chunk {chunk_idx}): {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to process chunk {chunk_idx}: {e}", vram_status_text

        if not segments:
            return None, "❌ No audio generated.", vram_status_text

        logger.info("Combining chunks...")
        final_seg = segments[0]
        overlap_ms = int(OVERLAP_SEC * 1000)
        for i in range(1, len(segments)):
            final_seg = _crossfade_segments(final_seg, segments[i], overlap_ms, PROCESS_SR, bit_depth_int)

        final_seg = final_seg[:total_duration * 1000]

        # Final mastering chain, then conversion to the requested output rate.
        final_seg = apply_noise_gate(final_seg, threshold_db=-80, sample_rate=PROCESS_SR)
        final_seg = balance_stereo(final_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
        final_seg = rms_normalize(final_seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
        final_seg = apply_eq(final_seg, sample_rate=PROCESS_SR)
        final_seg = apply_fade(final_seg, 500, 800)
        final_seg = final_seg - 10
        final_seg = final_seg.set_frame_rate(output_sr_int)

        mp3_path = f"ghostai_music_{int(time.time())}.mp3"
        try:
            clean_memory()
            final_seg.export(mp3_path, format="mp3", bitrate=bitrate, tags={"title": "GhostAI Instrumental", "artist": "GhostAI"})
        except Exception as e:
            logger.error(f"MP3 export failed ({bitrate}): {e}")
            fb = f"ghostai_music_fallback_{int(time.time())}.mp3"
            try:
                final_seg.export(fb, format="mp3", bitrate="128k")
                mp3_path = fb
            except Exception as ee:
                return None, f"❌ Failed to export MP3: {ee}", vram_status_text

        elapsed = time.time() - start_time
        vram_status_text = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        logger.info(f"Done in {elapsed:.2f}s -> {mp3_path}")
        return mp3_path, "✅ Done! 30s chunking unified seamlessly. Check output loudness/quality.", vram_status_text

    except Exception as e:
        logger.error(f"Generation failed: {e}")
        logger.error(traceback.format_exc())
        return None, f"❌ Generation failed: {e}", vram_status_text
    finally:
        clean_memory()
|
|
|
|
|
def clear_inputs():
    """Return the default value of every UI input, in Clear-button output order.

    Values come from a copy of DEFAULT_SETTINGS; the order matches the
    ``outputs=[...]`` list the Clear button is wired to.
    """
    defaults = DEFAULT_SETTINGS.copy()
    ordered_keys = (
        "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
        "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
        "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
        "bitrate", "output_sample_rate", "bit_depth",
    )
    return tuple(defaults[key] for key in ordered_keys)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Process-wide "a render is running" state, guarded by BUSY_LOCK.
# BUSY_FILE mirrors the flag on disk (it holds the current job id); /status
# reports whether that file exists.
BUSY_LOCK = threading.Lock()
BUSY_FLAG = False
BUSY_FILE = "/tmp/musicgen_busy.lock"
CURRENT_JOB: Dict[str, Any] = {"id": None, "start": None}


def set_busy(val: bool, job_id: Optional[str] = None):
    """Mark the server busy (``val=True``) or idle (``val=False``).

    Busy: stores ``job_id`` (or a generated ``job_<epoch-seconds>`` id) and the
    start time in CURRENT_JOB, then writes the id to BUSY_FILE.
    Idle: clears CURRENT_JOB and deletes BUSY_FILE.

    Lockfile I/O is best-effort. A failure is logged and never raised, so the
    in-memory flag stays authoritative.
    """
    global BUSY_FLAG
    with BUSY_LOCK:
        BUSY_FLAG = val
        if val:
            CURRENT_JOB["id"] = job_id or f"job_{int(time.time())}"
            CURRENT_JOB["start"] = time.time()
            try:
                # str() guards against a non-str id being passed by a caller.
                Path(BUSY_FILE).write_text(str(CURRENT_JOB["id"]), encoding="utf-8")
            except (OSError, ValueError) as exc:
                logger.warning(f"Could not write busy lockfile {BUSY_FILE}: {exc}")
        else:
            CURRENT_JOB["id"] = None
            CURRENT_JOB["start"] = None
            try:
                # missing_ok avoids the exists()-then-remove() race.
                Path(BUSY_FILE).unlink(missing_ok=True)
            except OSError as exc:
                logger.warning(f"Could not remove busy lockfile {BUSY_FILE}: {exc}")


def is_busy() -> bool:
    """Return the current busy flag (read under BUSY_LOCK)."""
    with BUSY_LOCK:
        return BUSY_FLAG


def job_elapsed() -> float:
    """Return seconds since the current job started, or 0.0 when no job is recorded."""
    with BUSY_LOCK:
        started = CURRENT_JOB["start"]
    return 0.0 if started is None else time.time() - started
|
|
|
|
|
class RenderRequest(BaseModel):
    """JSON body accepted by POST /render.

    Only ``instrumental_prompt`` is required. Every other field defaults to
    None, meaning "use the stored setting": the /render handler copies only
    non-None fields over CURRENT_SETTINGS for that one call.
    """

    instrumental_prompt: str
    # Sampling controls forwarded to generation.
    cfg_scale: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    # Song length in seconds (the UI offers 30/60/90/120).
    total_duration: Optional[int] = None
    bpm: Optional[int] = None
    # Style selectors; the UI uses "none" to mean no selection.
    drum_beat: Optional[str] = None
    synthesizer: Optional[str] = None
    rhythmic_steps: Optional[str] = None
    bass_style: Optional[str] = None
    guitar_style: Optional[str] = None
    # Loudness target in dBFS RMS (the UI slider spans -30 to -20).
    target_volume: Optional[float] = None
    preset: Optional[str] = None
    max_steps: Optional[int] = None
    # Output format as strings, e.g. "192k", "48000", "16" (the UI defaults).
    bitrate: Optional[str] = None
    output_sample_rate: Optional[str] = None
    bit_depth: Optional[str] = None
|
|
|
|
|
class SettingsUpdate(BaseModel):
    """JSON body accepted by POST /settings.

    ``settings`` maps setting names to new values. Keys and value types are
    not validated here; they are merged as-is into the stored settings.
    """

    settings: Dict[str, Any]
|
|
|
|
|
# HTTP API that runs alongside the Gradio UI (served by _start_fastapi).
fastapp = FastAPI(version="1.0", title="GhostAI Music Server")

# Fully permissive CORS: any origin, method and header.
# NOTE(review): Starlette documents that credentials cannot be allowed when
# origins are the "*" wildcard. Confirm whether allow_credentials=True is
# actually needed; no cookie/auth use is visible in this API.
fastapp.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
@fastapp.get("/health")
def health():
    """Liveness probe: always ``ok``, plus the current Unix time in whole seconds."""
    now = int(time.time())
    return {"ok": True, "ts": now}
|
|
|
|
|
@fastapp.get("/status")
def status():
    """Report render state as one consistent snapshot.

    Returns:
        ``busy``: current busy flag.
        ``job_id``: id of the running job, or None.
        ``since``: job start as epoch seconds, or None.
        ``elapsed``: seconds since ``since``; 0.0 when idle.
        ``lockfile``: whether BUSY_FILE currently exists on disk.
    """
    # Read the flag and job fields under ONE lock acquisition. Reading them
    # separately lets a concurrent set_busy() land in between, producing
    # e.g. busy=False with a job_id, or an elapsed time from a different job.
    with BUSY_LOCK:
        busy = BUSY_FLAG
        job_id = CURRENT_JOB["id"]
        since = CURRENT_JOB["start"]
    elapsed = 0.0 if since is None else time.time() - since
    return {
        "busy": busy,
        "job_id": job_id,
        "since": since,
        "elapsed": elapsed,
        "lockfile": os.path.exists(BUSY_FILE),
    }
|
|
|
|
|
@fastapp.get("/config")
def get_config():
    """Return the live in-memory settings under the ``defaults`` key."""
    payload = {"defaults": CURRENT_SETTINGS}
    return payload
|
|
|
|
|
@fastapp.post("/settings")
def set_settings(payload: SettingsUpdate):
    """Merge posted keys over the current settings, persist them, then apply them.

    The merged dict is written with save_settings_to_file() before
    CURRENT_SETTINGS is touched, so a failed save leaves memory unchanged.

    Returns:
        ``{"ok": True, "saved": <merged settings>}``.

    Raises:
        HTTPException(400): carrying the error text if any step fails.
    """
    try:
        merged = {**CURRENT_SETTINGS, **(payload.settings or {})}
        save_settings_to_file(merged)
        for key in merged:
            CURRENT_SETTINGS[key] = merged[key]
        return {"ok": True, "saved": merged}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
|
@fastapp.post("/render")
def render(req: RenderRequest):
    """Run one synchronous render from the stored settings plus request overrides.

    Non-None request fields override CURRENT_SETTINGS for this call only; the
    overrides are not persisted. Only one render may run at a time.

    Returns:
        dict with ``ok``, ``job_id``, ``path`` (MP3 path reported by
        generate_music), ``status`` and ``vram`` text.

    Raises:
        HTTPException(409): a render is already in progress.
        HTTPException(500): generate_music produced no file.
    """
    global BUSY_FLAG
    # Check and claim the busy flag under a single lock acquisition. FastAPI
    # runs sync endpoints in a thread pool, so a separate is_busy() check
    # followed by set_busy(True) let two concurrent requests both pass the
    # check and start rendering simultaneously.
    with BUSY_LOCK:
        if BUSY_FLAG:
            raise HTTPException(status_code=409, detail="Server busy")
        BUSY_FLAG = True

    job_id = f"render_{int(time.time())}"
    try:
        # Record job id, start time and lockfile. The flag is already claimed;
        # this is inside the try so any failure still releases it.
        set_busy(True, job_id)

        s = CURRENT_SETTINGS.copy()
        for k, v in req.dict().items():
            if v is not None:
                s[k] = v

        def setting(key, cast):
            # Stored/overridden value, falling back to DEFAULT_SETTINGS, coerced to `cast`.
            return cast(s.get(key, DEFAULT_SETTINGS[key]))

        mp3, msg, vram = generate_music(
            s.get("instrumental_prompt", req.instrumental_prompt),
            setting("cfg_scale", float),
            setting("top_k", int),
            setting("top_p", float),
            setting("temperature", float),
            setting("total_duration", int),
            setting("bpm", int),
            setting("drum_beat", str),
            setting("synthesizer", str),
            setting("rhythmic_steps", str),
            setting("bass_style", str),
            setting("guitar_style", str),
            setting("target_volume", float),
            setting("preset", str),
            setting("max_steps", str),
            # Presumably the VRAM-status text slot (the UI wires vram_box at
            # this position); the API has no prior status, so pass "".
            "",
            setting("bitrate", str),
            setting("output_sample_rate", str),
            setting("bit_depth", str),
        )

        if not mp3:
            raise HTTPException(status_code=500, detail=msg)

        return {"ok": True, "job_id": job_id, "path": mp3, "status": msg, "vram": vram}

    finally:
        set_busy(False, None)
|
|
|
|
|
def _start_fastapi():
    """Serve `fastapp` with uvicorn on all interfaces, port 8555 (used as a thread target)."""
    host, port = "0.0.0.0", 8555
    uvicorn.run(fastapp, host=host, port=port, log_level="info")
|
|
|
|
|
# Serve the HTTP API from a daemon thread so it does not keep the process
# alive on its own. The Gradio UI launched later in the file runs in the
# main thread.
api_thread = threading.Thread(daemon=True, target=_start_fastapi)
api_thread.start()
# NOTE(review): logged right after start(); uvicorn may not have bound the port yet.
logger.info("FastAPI server started on http://0.0.0.0:8555")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dark, high-contrast theme passed to gr.Blocks(css=CSS) below.
# The universal "* { color: #FFFFFF !important; }" rule forces white text on
# every element. Buttons are blue with a green focus outline.
CSS = """
:root { color-scheme: dark; }
body, .gradio-container, .block, .tabs, .panel, .form, .wrap { background: #0B0B0D !important; color: #FFFFFF !important; }
* { color: #FFFFFF !important; }
label, p, span, h1, h2, h3, h4, h5, h6 { color: #FFFFFF !important; }
input, textarea, select { background: #15151A !important; color: #FFFFFF !important; border: 1px solid #2B2B33 !important; }
button { background: #1F6FEB !important; color: #FFFFFF !important; border: 2px solid transparent !important; border-radius: 8px !important; padding: 10px 16px !important; font-weight: 700 !important; }
button:hover { background: #2D7BFF !important; }
button:focus { outline: 3px solid #00C853 !important; }
.slider > input { accent-color: #FFD600 !important; }
.group-container { border: 1px solid #2B2B33; border-radius: 10px; padding: 16px; }
.header { text-align:center; padding: 12px 16px; border-bottom: 2px solid #00C853; }
.header h1 { font-size: 28px; margin: 6px 0 0 0; }
.header .logo { font-size: 44px; }
"""
|
|
|
|
|
logger.info("Building Gradio UI...")
# Alias (not a copy) of the live settings dict; seeds the initial widget values.
loaded = CURRENT_SETTINGS
|
|
# Build the Gradio UI. All components and event listeners are created inside
# this Blocks context so they are registered on `demo`.
# NOTE(review): this chunk arrived with its indentation stripped. The nesting of
# the Column/Row/Group contexts below is reconstructed from statement order;
# confirm it against the rendered layout.
with gr.Blocks(css=CSS, analytics_enabled=False, title="GhostAI Music Generator") as demo:
    # Page header: raw HTML rendered through Markdown, styled by the .header CSS rules.
    gr.Markdown(f"""
<div class="header" role="banner" aria-label="GhostAI Music Generator">
<div class="logo">👻</div>
<h1>GhostAI Music Generator</h1>
<p>30s chunking, seamless joins, 1-minute ready, API status & settings save</p>
</div>
""")

    # Prompt entry plus one-click genre/band prompt buttons.
    with gr.Column(elem_classes="input-container"):
        gr.Markdown("### Prompt")
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt",
            placeholder="Type your instrumental prompt or use genre buttons below",
            lines=4,
            value=loaded.get("instrumental_prompt", ""),
        )
        with gr.Row():
            rhcp_btn = gr.Button("Red Hot Chili Peppers 🌶️")
            nirvana_btn = gr.Button("Nirvana 🎸")
            pearl_jam_btn = gr.Button("Pearl Jam 🦪")
            soundgarden_btn = gr.Button("Soundgarden 🌑")
            foo_fighters_btn = gr.Button("Foo Fighters 🤘")
        with gr.Row():
            smashing_pumpkins_btn = gr.Button("Smashing Pumpkins 🎃")
            radiohead_btn = gr.Button("Radiohead 🧠")
            classic_rock_btn = gr.Button("Metallica Heavy Metal 🎸")
            alternative_rock_btn = gr.Button("Alternative Rock 🎵")
            post_punk_btn = gr.Button("Post-Punk 🖤")
        with gr.Row():
            indie_rock_btn = gr.Button("Indie Rock 🎤")
            funk_rock_btn = gr.Button("Funk Rock 🕺")
            detroit_techno_btn = gr.Button("Detroit Techno 🎛️")
            deep_house_btn = gr.Button("Deep House 🏠")

    # Generation parameters. Initial values come from `loaded`, falling back
    # to DEFAULT_SETTINGS or literal defaults.
    with gr.Column(elem_classes="settings-container"):
        gr.Markdown("### Settings")
        with gr.Group(elem_classes="group-container"):
            cfg_scale = gr.Slider(1.0, 10.0, step=0.1, value=float(loaded.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), label="CFG Scale")
            top_k = gr.Slider(10, 500, step=10, value=int(loaded.get("top_k", DEFAULT_SETTINGS["top_k"])), label="Top-K")
            top_p = gr.Slider(0.0, 1.0, step=0.01, value=float(loaded.get("top_p", DEFAULT_SETTINGS["top_p"])), label="Top-P")
            temperature = gr.Slider(0.1, 2.0, step=0.01, value=float(loaded.get("temperature", DEFAULT_SETTINGS["temperature"])), label="Temperature")
            total_duration = gr.Dropdown(choices=[30, 60, 90, 120], value=int(loaded.get("total_duration", 60)), label="Song Length (seconds)")
            bpm = gr.Slider(60, 180, step=1, value=int(loaded.get("bpm", 120)), label="Tempo (BPM)")
            drum_beat = gr.Dropdown(choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"], value=str(loaded.get("drum_beat", "none")), label="Drum Beat")
            synthesizer = gr.Dropdown(choices=["none", "analog synth", "digital pad", "arpeggiated synth"], value=str(loaded.get("synthesizer", "none")), label="Synthesizer")
            rhythmic_steps = gr.Dropdown(choices=["none", "syncopated steps", "steady steps", "complex steps"], value=str(loaded.get("rhythmic_steps", "none")), label="Rhythmic Steps")
            bass_style = gr.Dropdown(choices=["none", "slap bass", "deep bass", "melodic bass"], value=str(loaded.get("bass_style", "none")), label="Bass Style")
            guitar_style = gr.Dropdown(choices=["none", "distorted", "clean", "jangle"], value=str(loaded.get("guitar_style", "none")), label="Guitar Style")
            target_volume = gr.Slider(-30.0, -20.0, step=0.5, value=float(loaded.get("target_volume", -23.0)), label="Target Loudness (dBFS RMS)")
            preset = gr.Dropdown(choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], value=str(loaded.get("preset", "default")), label="Preset")
            max_steps = gr.Dropdown(choices=[1000, 1200, 1300, 1500], value=int(loaded.get("max_steps", 1500)), label="Max Steps (per chunk hint)")

        # Output-format choices have no visible widget: they are held in State
        # and changed only by the buttons below.
        bitrate_state = gr.State(value=str(loaded.get("bitrate", "192k")))
        sample_rate_state = gr.State(value=str(loaded.get("output_sample_rate", "48000")))
        bit_depth_state = gr.State(value=str(loaded.get("bit_depth", "16")))

        with gr.Row():
            bitrate_128_btn = gr.Button("Bitrate 128k")
            bitrate_192_btn = gr.Button("Bitrate 192k")
            bitrate_320_btn = gr.Button("Bitrate 320k")
        with gr.Row():
            sample_rate_22050_btn = gr.Button("SR 22.05k")
            sample_rate_44100_btn = gr.Button("SR 44.1k")
            sample_rate_48000_btn = gr.Button("SR 48k")
        with gr.Row():
            bit_depth_16_btn = gr.Button("16-bit")
            bit_depth_24_btn = gr.Button("24-bit")

        with gr.Row():
            gen_btn = gr.Button("Generate Music 🚀")
            clr_btn = gr.Button("Clear 🧹")
            save_btn = gr.Button("Save Settings 💾")
            load_btn = gr.Button("Load Settings 📂")
            reset_btn = gr.Button("Reset Defaults ♻️")

    # Results: generated file path, status text, VRAM readout.
    with gr.Column(elem_classes="output-container"):
        gr.Markdown("### Output")
        out_audio = gr.Audio(label="Generated Track", type="filepath")
        status_box = gr.Textbox(label="Status", interactive=False)
        vram_box = gr.Textbox(label="VRAM Usage", interactive=False, value="")

    with gr.Column(elem_classes="logs-container"):
        gr.Markdown("### Logs")
        log_output = gr.Textbox(label="Last Log File", lines=16, interactive=False)
        log_btn = gr.Button("View Last Log")

    # Genre buttons: each setter receives the current style controls and writes
    # a prompt into the prompt box.
    rhcp_btn.click(
        set_red_hot_chili_peppers_prompt,
        # NOTE(review): unlike the other genre buttons, this passes a 7th input
        # (a constant State of 1); confirm it matches the setter's signature.
        inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, gr.State(value=1)],
        outputs=instrumental_prompt
    )
    nirvana_btn.click(set_nirvana_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    pearl_jam_btn.click(set_pearl_jam_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    soundgarden_btn.click(set_soundgarden_grunge_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    foo_fighters_btn.click(set_foo_fighters_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    radiohead_btn.click(set_radiohead_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    classic_rock_btn.click(set_classic_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    alternative_rock_btn.click(set_alternative_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    post_punk_btn.click(set_post_punk_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    indie_rock_btn.click(set_indie_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    funk_rock_btn.click(set_funk_rock_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    detroit_techno_btn.click(set_detroit_techno_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)
    deep_house_btn.click(set_deep_house_prompt, [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], instrumental_prompt)

    # Output-format buttons write their value into the matching State.
    bitrate_128_btn.click(set_bitrate_128, outputs=bitrate_state)
    bitrate_192_btn.click(set_bitrate_192, outputs=bitrate_state)
    bitrate_320_btn.click(set_bitrate_320, outputs=bitrate_state)
    sample_rate_22050_btn.click(set_sample_rate_22050, outputs=sample_rate_state)
    sample_rate_44100_btn.click(set_sample_rate_44100, outputs=sample_rate_state)
    sample_rate_48000_btn.click(set_sample_rate_48000, outputs=sample_rate_state)
    bit_depth_16_btn.click(set_bit_depth_16, outputs=bit_depth_state)
    bit_depth_24_btn.click(set_bit_depth_24, outputs=bit_depth_state)

    # Generate. vram_box is both an input (its current text) and an output.
    gen_btn.click(
        generate_music_wrapper,
        inputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, vram_box, bitrate_state, sample_rate_state, bit_depth_state
        ],
        outputs=[out_audio, status_box, vram_box]
    )

    # Clear: reset every input to DEFAULT_SETTINGS without saving to disk.
    clr_btn.click(
        clear_inputs, outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
        ]
    )

    def _save_action(
        instrumental_prompt_v, cfg_v, top_k_v, top_p_v, temp_v, dur_v, bpm_v,
        drum_v, synth_v, steps_v, bass_v, guitar_v, vol_v, preset_v, maxsteps_v, br_v, sr_v, bd_v
    ):
        """Save the current widget values to disk and mirror them into CURRENT_SETTINGS.

        Values are coerced to the types shown below before saving. Exceptions
        from save_settings_to_file() are not caught here.

        Returns:
            Status text for status_box.
        """
        s = {
            "instrumental_prompt": instrumental_prompt_v,
            "cfg_scale": float(cfg_v),
            "top_k": int(top_k_v),
            "top_p": float(top_p_v),
            "temperature": float(temp_v),
            "total_duration": int(dur_v),
            "bpm": int(bpm_v),
            "drum_beat": str(drum_v),
            "synthesizer": str(synth_v),
            "rhythmic_steps": str(steps_v),
            "bass_style": str(bass_v),
            "guitar_style": str(guitar_v),
            "target_volume": float(vol_v),
            "preset": str(preset_v),
            "max_steps": int(maxsteps_v),
            "bitrate": str(br_v),
            "output_sample_rate": str(sr_v),
            "bit_depth": str(bd_v)
        }
        save_settings_to_file(s)
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return "✅ Settings saved."

    def _load_action():
        """Load settings from disk into CURRENT_SETTINGS and return them in widget order.

        NOTE(review): assumes load_settings_from_file() returns every key indexed
        below; a missing key raises KeyError.

        Returns:
            18 widget values followed by status text for status_box.
        """
        s = load_settings_from_file()
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return (
            s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"],
            s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"],
            s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"],
            s["bitrate"], s["output_sample_rate"], s["bit_depth"],
            "✅ Settings loaded."
        )

    def _reset_action():
        """Persist DEFAULT_SETTINGS, mirror it into CURRENT_SETTINGS, and return it in widget order.

        Returns:
            18 widget values followed by status text for status_box.
        """
        s = DEFAULT_SETTINGS.copy()
        save_settings_to_file(s)
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return (
            s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"],
            s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"],
            s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"],
            s["bitrate"], s["output_sample_rate"], s["bit_depth"],
            "✅ Defaults restored."
        )

    # Settings persistence buttons. Input/output order must match the
    # parameter and return order of the three actions above.
    save_btn.click(
        _save_action,
        inputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
        ],
        outputs=status_box
    )

    load_btn.click(
        _load_action,
        outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
        ]
    )

    reset_btn.click(
        _reset_action,
        outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
        ]
    )

    log_btn.click(get_latest_log, outputs=log_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info("Launching Gradio UI at http://0.0.0.0:9999 ...")
# Serve the UI on all interfaces, port 9999, with no public share link and no
# browser auto-open. Any launch failure is logged with its traceback and the
# process exits with status 1.
try:
    demo.launch(
        show_error=True,
        inbrowser=False,
        share=False,
        server_port=9999,
        server_name="0.0.0.0",
    )
except Exception as launch_err:
    logger.error(f"Failed to launch Gradio UI: {launch_err}")
    logger.error(traceback.format_exc())
    sys.exit(1)
|
|
|