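"""GhostAI Music Generator: Gradio UI plus FastAPI server around the audiocraft MusicGen model.

Generates instrumental tracks in 30 s chunks (crossfaded into 30-120 s songs),
post-processes them with pydub/numpy (noise gate, stereo balance, RMS
normalization, EQ, fades), and exports MP3s. Styles and prompt templates come
from prompts.ini; defaults persist in settings.json.
"""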
import os
import sys
import gc
import re
import json
import time
import mmap
import math
import torch
import random
import logging
import warnings
import traceback
import subprocess
import numpy as np
import torchaudio
import gradio as gr
import gradio_client.utils
import threading
import configparser
from pydub import AudioSegment
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, List
from torch.cuda.amp import autocast
from logging.handlers import RotatingFileHandler

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

from colorama import init as colorama_init, Fore, Style

RELEASE = "v1.3.2"

# gradio_client's get_type() crashes on boolean JSON schemas in some releases;
# short-circuit that case before handing off to the original implementation.
_original_get_type = gradio_client.utils.get_type


def _patched_get_type(schema):
    if isinstance(schema, bool):
        return "boolean"
    return _original_get_type(schema)


gradio_client.utils.get_type = _patched_get_type

warnings.filterwarnings("ignore")

# Favor small, recyclable CUDA allocations and deterministic cuDNN kernels.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

BASE_DIR = Path(__file__).parent.resolve()
LOG_DIR = BASE_DIR / "logs"
MP3_DIR = BASE_DIR / "mp3"
LOG_DIR.mkdir(parents=True, exist_ok=True)
MP3_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE = LOG_DIR / "ghostai_musicgen.log"
logger = logging.getLogger("ghostai-musicgen")
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=0, encoding="utf-8")
file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(file_handler)
logger.addHandler(console_handler)

colorama_init()
print(f"{Fore.CYAN}GhostAI Music Generator {Fore.MAGENTA}{RELEASE}{Fore.RESET} · {Fore.GREEN}Booting...{Fore.RESET}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    print(f"{Fore.RED}CUDA not available. Exiting.{Fore.RESET}")
    logger.error("CUDA is required. Exiting.")
    sys.exit(1)

gpu_name = torch.cuda.get_device_name(0)
print(f"{Fore.YELLOW}GPU:{Fore.RESET} {gpu_name}")
print(f"{Fore.YELLOW}Precision:{Fore.RESET} fp16 (model) / fp32 (CPU audio ops)")

CSS_FILE = BASE_DIR / "styles.css"
PROMPTS_INI = BASE_DIR / "prompts.ini"
EXAMPLES_MD = BASE_DIR / "examples.md"
SETTINGS_FILE = BASE_DIR / "settings.json"

DEFAULT_SETTINGS: Dict[str, Any] = {
    "cfg_scale": 5.8,
    "top_k": 250,
    "top_p": 0.95,
    "temperature": 0.90,
    "total_duration": 60,
    "bpm": 120,
    "drum_beat": "none",
    "synthesizer": "none",
    "rhythmic_steps": "none",
    "bass_style": "none",
    "guitar_style": "none",
    "target_volume": -23.0,
    "preset": "default",
    "max_steps": 1500,
    "bitrate": "192k",
    "output_sample_rate": "48000",
    "bit_depth": "16",
    "instrumental_prompt": ""
}


def load_settings() -> Dict[str, Any]:
    if SETTINGS_FILE.exists():
        try:
            data = json.loads(SETTINGS_FILE.read_text())
            for k, v in DEFAULT_SETTINGS.items():
                data.setdefault(k, v)
            logger.info("Settings loaded.")
            return data
        except Exception as e:
            logger.error(f"Settings read failed: {e}")
    return DEFAULT_SETTINGS.copy()


def save_settings(s: Dict[str, Any]) -> None:
    try:
        SETTINGS_FILE.write_text(json.dumps(s, indent=2))
        logger.info("Settings saved.")
    except Exception as e:
        logger.error(f"Settings write failed: {e}")


CURRENT_SETTINGS = load_settings()
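
# settings.json (next to this script) is a flat JSON object mirroring
# DEFAULT_SETTINGS; missing keys are back-filled from the defaults above,
# e.g. (illustrative):
#   {"cfg_scale": 5.8, "top_k": 250, "bitrate": "192k", "bit_depth": "16"}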


def clean_memory() -> Optional[float]:
    """Release cached CUDA memory and report what remains allocated (MB)."""
    try:
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        vram_mb = torch.cuda.memory_allocated() / 1024**2
        logger.debug(f"Memory cleaned. VRAM={vram_mb:.2f} MB")
        return vram_mb
    except Exception as e:
        logger.error(f"clean_memory failed: {e}")
        logger.error(traceback.format_exc())
        return None


def check_vram():
    """Query nvidia-smi for used/total VRAM; returns free MiB, or None on failure."""
    try:
        r = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv"],
            capture_output=True, text=True
        )
        lines = r.stdout.splitlines()
        if len(lines) > 1:
            used_mb, total_mb = map(int, re.findall(r"\d+", lines[1]))
            free_mb = total_mb - used_mb
            logger.info(f"VRAM: used {used_mb} MiB | free {free_mb} MiB | total {total_mb} MiB")
            if free_mb < 5000:
                # Log which processes hold the GPU when headroom is tight.
                procs = subprocess.run(
                    ["nvidia-smi", "--query-compute-apps=pid,used_memory", "--format=csv"],
                    capture_output=True, text=True
                )
                logger.info(f"GPU processes:\n{procs.stdout}")
            return free_mb
    except Exception as e:
        logger.error(f"check_vram failed: {e}")
    return None
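
# `nvidia-smi --query-gpu=memory.used,memory.total --format=csv` prints a
# header row plus one row per GPU, e.g. (illustrative):
#   memory.used [MiB], memory.total [MiB]
#   1234 MiB, 24576 MiB
# check_vram() regex-scans the digits in the first data row, so the "MiB"
# units are tolerated; --format=csv,noheader,nounits would be stricter.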


def check_disk_space(path=".") -> bool:
    try:
        stat = os.statvfs(path)
        free_gb = stat.f_bavail * stat.f_frsize / (1024**3)
        if free_gb < 1.0:
            logger.warning(f"Low disk space: {free_gb:.2f} GB")
        return free_gb >= 1.0
    except Exception as e:
        logger.error(f"Disk space check failed: {e}")
        return False


def ensure_stereo(seg: AudioSegment, sample_rate=48000, sample_width=2) -> AudioSegment:
    try:
        if seg.channels != 2:
            seg = seg.set_channels(2)
        if seg.frame_rate != sample_rate:
            seg = seg.set_frame_rate(sample_rate)
        return seg
    except Exception as e:
        logger.error(f"ensure_stereo failed: {e}")
        return seg


def calculate_rms(seg: AudioSegment) -> float:
    try:
        samples = np.array(seg.get_array_of_samples(), dtype=np.float32)
        return float(np.sqrt(np.mean(samples**2)))
    except Exception:
        return 0.0


def hard_limit(seg: AudioSegment, limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        limit = 10 ** (limit_db / 20.0) * (2**23 if seg.sample_width == 3 else 32767)
        samples = np.array(seg.get_array_of_samples(), dtype=np.float32)
        samples = np.clip(samples, -limit, limit).astype(np.int32 if seg.sample_width == 3 else np.int16)
        if len(samples) % 2 != 0:
            samples = samples[:-1]
        return AudioSegment(
            samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=seg.sample_width,
            channels=2
        )
    except Exception as e:
        logger.error(f"hard_limit failed: {e}")
        return seg


def rms_normalize(seg: AudioSegment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        target_rms = 10 ** (target_rms_db / 20) * (2**23 if seg.sample_width == 3 else 32767)
        current = calculate_rms(seg)
        if current > 0:
            gain = target_rms / current
            seg = seg.apply_gain(20 * np.log10(max(gain, 1e-6)))
        return hard_limit(seg, peak_limit_db, sample_rate)
    except Exception as e:
        logger.error(f"rms_normalize failed: {e}")
        return seg
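
# Worked example for rms_normalize: at 16-bit (full scale 32767), a target of
# -23 dBFS gives target_rms = 10 ** (-23 / 20) * 32767, roughly 2320. A chunk
# whose measured RMS is 4640 would receive 20 * log10(2320 / 4640), about
# -6.02 dB of gain, before the final hard_limit pass.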


def balance_stereo(seg: AudioSegment, noise_threshold=-40, sample_rate=48000) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        arr = np.array(seg.get_array_of_samples(), dtype=np.float32)
        if seg.channels != 2:
            return seg
        stereo = arr.reshape(-1, 2)
        db = 20 * np.log10(np.abs(stereo) + 1e-10)
        mask = db > noise_threshold
        stereo = stereo * mask
        left, right = stereo[:, 0], stereo[:, 1]
        l_rms = np.sqrt(np.mean(left[left != 0] ** 2)) if np.any(left != 0) else 0
        r_rms = np.sqrt(np.mean(right[right != 0] ** 2)) if np.any(right != 0) else 0
        if l_rms > 0 and r_rms > 0:
            # Nudge both channels toward their common average level.
            avg = (l_rms + r_rms) / 2
            stereo[:, 0] *= (avg / l_rms)
            stereo[:, 1] *= (avg / r_rms)
        out = stereo.flatten().astype(np.int32 if seg.sample_width == 3 else np.int16)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(out.tobytes(), frame_rate=sample_rate, sample_width=seg.sample_width, channels=2)
    except Exception as e:
        logger.error(f"balance_stereo failed: {e}")
        return seg


def apply_noise_gate(seg: AudioSegment, threshold_db=-80, sample_rate=48000) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        arr = np.array(seg.get_array_of_samples(), dtype=np.float32)
        if seg.channels != 2:
            return seg
        stereo = arr.reshape(-1, 2)
        for _ in range(2):
            db = 20 * np.log10(np.abs(stereo) + 1e-10)
            stereo = stereo * (db > threshold_db)
        out = stereo.flatten().astype(np.int32 if seg.sample_width == 3 else np.int16)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(out.tobytes(), frame_rate=sample_rate, sample_width=seg.sample_width, channels=2)
    except Exception as e:
        logger.error(f"apply_noise_gate failed: {e}")
        return seg


def apply_eq(seg: AudioSegment, sample_rate=48000) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        seg = seg.high_pass_filter(20)
        seg = seg.low_pass_filter(8000)
        # Staged gain reduction: -3 dB, -3 dB, then -10 dB (-16 dB total).
        seg = seg - 3
        seg = seg - 3
        seg = seg - 10
        return seg
    except Exception as e:
        logger.error(f"apply_eq failed: {e}")
        return seg


def apply_fade(seg: AudioSegment, fade_in=500, fade_out=800) -> AudioSegment:
    try:
        seg = ensure_stereo(seg, seg.frame_rate, seg.sample_width)
        return seg.fade_in(fade_in).fade_out(fade_out)
    except Exception as e:
        logger.error(f"apply_fade failed: {e}")
        return seg


class SafeFormatDict(dict):
    def __missing__(self, key):
        return ""
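
# SafeFormatDict lets prompt templates reference optional fields without
# raising KeyError; missing keys render as empty strings:
#   "{mood} at {bpm} BPM".format_map(SafeFormatDict(bpm=128)) -> " at 128 BPM"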


class StylesConfig:
    def __init__(self, path: Path):
        self.path = path
        self.cfg = configparser.ConfigParser(interpolation=None)
        self.mtime = 0.0
        self.styles: Dict[str, Dict[str, Any]] = {}
        self._load()

    def _load(self):
        if not self.path.exists():
            logger.error(f"prompts.ini not found: {self.path}")
            self.cfg = configparser.ConfigParser(interpolation=None)
            self.styles = {}
            self.mtime = 0.0
            return
        self.cfg.read(self.path, encoding="utf-8")
        self.styles = {}
        for sec in self.cfg.sections():
            d: Dict[str, Any] = {k: v for k, v in self.cfg.items(sec)}
            listish = {
                "drum_beat", "synthesizer", "rhythmic_steps", "bass_style", "guitar_style",
                "variations", "mood", "genre", "key", "scale", "feel", "instrument",
                "lead", "pad", "arp", "drums", "bass", "guitar", "strings", "brass", "woodwinds",
                "structure"
            }
            for key in listish:
                if key in d and isinstance(d[key], str):
                    d[key] = [s.strip() for s in d[key].split(",") if s.strip()]
            self.styles[sec] = d
        self.mtime = self.path.stat().st_mtime
        logger.info(f"Loaded {len(self.styles)} styles from prompts.ini")

    def maybe_reload(self):
        if self.path.exists():
            mt = self.path.stat().st_mtime
            if mt != self.mtime:
                self._load()

    def list_styles(self) -> List[str]:
        self.maybe_reload()
        return list(self.styles.keys())

    def _pick_from_list(self, vals: Any) -> str:
        if isinstance(vals, list):
            return random.choice(vals) if vals else ""
        return str(vals or "")

    def build_prompt(
        self,
        style: str,
        bpm: int,
        chunk_num: int = 1,
        drum_beat: str = "none",
        synthesizer: str = "none",
        rhythmic_steps: str = "none",
        bass_style: str = "none",
        guitar_style: str = "none"
    ) -> str:
        self.maybe_reload()
        if style not in self.styles:
            return ""
        s = self.styles[style]

        bpm_min = int(s.get("bpm_min", "100"))
        bpm_max = int(s.get("bpm_max", "140"))
        final_bpm = bpm if bpm != 120 else random.randint(bpm_min, bpm_max)

        def choose(field_name: str, incoming: str) -> str:
            if incoming and incoming != "none":
                return incoming
            return self._pick_from_list(s.get(field_name, [])) or ""

        d = choose("drum_beat", drum_beat)
        syn = choose("synthesizer", synthesizer)
        r = choose("rhythmic_steps", rhythmic_steps)
        b = choose("bass_style", bass_style)
        g = choose("guitar_style", guitar_style)

        var_list = s.get("variations", [])
        variation = ""
        if isinstance(var_list, list) and var_list:
            if chunk_num == 1:
                variation = random.choice(var_list[: max(1, len(var_list)//2)])
            else:
                variation = random.choice(var_list)

        fields: Dict[str, Any] = {}
        for k, v in s.items():
            fields[k] = self._pick_from_list(v) if isinstance(v, list) else v

        if "structure" in s:
            fields["section"] = self._pick_from_list(s["structure"])

        fields.update({
            "bpm": final_bpm,
            "chunk": chunk_num,
            "drum": f" {d}" if d else "",
            "synth": f" {syn}" if syn else "",
            "rhythm": f" {r}" if r else "",
            "bass": f" {b}" if b else "",
            "guitar": f" {g}" if g else "",
            "variation": variation
        })

        tpl = s.get(
            "prompt_template",
            "Instrumental track at {bpm} BPM {variation}. {mood} {section} {drum}{bass}{guitar}{synth}{rhythm}"
        )

        prompt = tpl.format_map(SafeFormatDict(fields))
        prompt = re.sub(r"\s{2,}", " ", prompt).strip()
        return prompt

    def style_defaults_for_ui(self, style: str) -> Dict[str, Any]:
        self.maybe_reload()
        s = self.styles.get(style, {})
        bpm_min = int(s.get("bpm_min", "100"))
        bpm_max = int(s.get("bpm_max", "140"))
        chosen = {
            "bpm": random.randint(bpm_min, bpm_max),
            "drum_beat": self._pick_from_list(s.get("drum_beat", [])) or "none",
            "synthesizer": self._pick_from_list(s.get("synthesizer", [])) or "none",
            "rhythmic_steps": self._pick_from_list(s.get("rhythmic_steps", [])) or "none",
            "bass_style": self._pick_from_list(s.get("bass_style", [])) or "none",
            "guitar_style": self._pick_from_list(s.get("guitar_style", [])) or "none",
        }
        for k, v in chosen.items():
            if v == "":
                chosen[k] = "none"
        return chosen


STYLES = StylesConfig(PROMPTS_INI)
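
# Illustrative prompts.ini section (field names are the ones build_prompt and
# style_defaults_for_ui consume; comma-separated values become choice lists):
#   [metallica]
#   bpm_min = 120
#   bpm_max = 180
#   drum_beat = heavy drums, precise drums
#   guitar_style = thrash riffing, downpicked
#   variations = aggressive verse riff, galloping bridge
#   prompt_template = Instrumental track at {bpm} BPM {variation}. {mood} {section} {drum}{bass}{guitar}{synth}{rhythm}
#   api_name = /metallica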


try:
    from audiocraft.models import MusicGen
except Exception:
    logger.error("audiocraft is required. pip install audiocraft")
    raise


def load_model():
    free = check_vram()
    if free is not None and free < 5000:
        logger.warning("Low free VRAM; consider closing other apps.")
    clean_memory()
    # Expects a local MusicGen-large snapshot under ./models/musicgen-large.
    local_model_path = str(BASE_DIR / "models" / "musicgen-large")
    if not os.path.exists(local_model_path):
        logger.error(f"Model path missing: {local_model_path}")
        sys.exit(1)
    logger.info("Loading MusicGen (large)...")
    with autocast(dtype=torch.float16):
        model = MusicGen.get_pretrained(local_model_path, device=DEVICE)
        model.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen loaded.")
    return model


musicgen_model = load_model()


def _export_torch_to_segment(audio_tensor: torch.Tensor, sample_rate: int, bit_depth_int: int) -> Optional[AudioSegment]:
    """Round-trip a tensor through a temporary WAV file into a pydub AudioSegment."""
    tmp = f"temp_audio_{int(time.time()*1000)}.wav"
    try:
        torchaudio.save(tmp, audio_tensor, sample_rate, bits_per_sample=bit_depth_int)
        return AudioSegment.from_wav(tmp)
    except Exception as e:
        logger.error(f"_export_torch_to_segment failed: {e}")
        logger.error(traceback.format_exc())
        return None
    finally:
        try:
            if os.path.exists(tmp):
                os.remove(tmp)
        except OSError:
            pass


def _crossfade(seg_a: AudioSegment, seg_b: AudioSegment, overlap_ms: int, sr: int, bit_depth_int: int) -> AudioSegment:
    try:
        seg_a = ensure_stereo(seg_a, sr, seg_a.sample_width)
        seg_b = ensure_stereo(seg_b, sr, seg_b.sample_width)
        if overlap_ms <= 0 or len(seg_a) < overlap_ms or len(seg_b) < overlap_ms:
            return seg_a + seg_b
        prev_wav = f"tmp_prev_{int(time.time()*1000)}.wav"
        curr_wav = f"tmp_curr_{int(time.time()*1000)}.wav"
        try:
            seg_a[-overlap_ms:].export(prev_wav, format="wav")
            seg_b[:overlap_ms].export(curr_wav, format="wav")
            a_audio, sra = torchaudio.load(prev_wav)
            b_audio, srb = torchaudio.load(curr_wav)
            if sra != sr:
                a_audio = torchaudio.functional.resample(a_audio, sra, sr, lowpass_filter_width=64)
            if srb != sr:
                b_audio = torchaudio.functional.resample(b_audio, srb, sr, lowpass_filter_width=64)
            n = min(a_audio.shape[1], b_audio.shape[1])
            n = n - (n % 2)
            if n <= 0:
                return seg_a + seg_b
            a = a_audio[:, :n]
            b = b_audio[:, :n]
            # Complementary raised-cosine ramps: fade_in rises 0->1 and
            # fade_out = 1 - fade_in, so the pair sums to unity everywhere in
            # the overlap (equal-gain crossfade). A symmetric Hann window is
            # its own mirror image, so flipping one does not yield a
            # complementary ramp.
            t = torch.linspace(0.0, 1.0, n)
            fade_in = 0.5 - 0.5 * torch.cos(math.pi * t)
            fade_out = 1.0 - fade_in
            blended = (a * fade_out + b * fade_in).to(torch.float32).clamp(-1.0, 1.0)
            scale = (2**23 if bit_depth_int == 24 else 32767)
            blended_i = (blended * scale).to(torch.int32 if bit_depth_int == 24 else torch.int16)
            tmpx = f"tmp_cross_{int(time.time()*1000)}.wav"
            torchaudio.save(tmpx, blended_i, sr, bits_per_sample=bit_depth_int)
            blend_seg = AudioSegment.from_wav(tmpx)
            blend_seg = ensure_stereo(blend_seg, sr, blend_seg.sample_width)
            result = seg_a[:-overlap_ms] + blend_seg + seg_b[overlap_ms:]
            try:
                if os.path.exists(tmpx):
                    os.remove(tmpx)
            except OSError:
                pass
            return result
        finally:
            for p in [prev_wav, curr_wav]:
                try:
                    if os.path.exists(p):
                        os.remove(p)
                except OSError:
                    pass
    except Exception as e:
        logger.error(f"_crossfade failed: {e}")
        return seg_a + seg_b
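
# The 0.20 s overlap generate_music() passes in below is 9600 samples per
# channel at its 48 kHz processing rate: short enough to stay cheap, long
# enough to hide the chunk seams.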


def generate_music(
    instrumental_prompt: str,
    cfg_scale: float,
    top_k: int,
    top_p: float,
    temperature: float,
    total_duration: int,
    bpm: int,
    drum_beat: str,
    synthesizer: str,
    rhythmic_steps: str,
    bass_style: str,
    guitar_style: str,
    target_volume: float,
    preset: str,
    max_steps: str,
    vram_status_text: str,
    bitrate: str,
    output_sample_rate: str,
    bit_depth: str
) -> Tuple[Optional[str], str, str]:
    if not instrumental_prompt.strip():
        return None, "⚠️ Enter a prompt.", vram_status_text

    try:
        out_sr = int(output_sample_rate)
    except (TypeError, ValueError):
        return None, "❌ Invalid sample rate.", vram_status_text
    try:
        bd = int(bit_depth)
        sample_width = 3 if bd == 24 else 2
    except (TypeError, ValueError):
        return None, "❌ Invalid bit depth.", vram_status_text
    if not check_disk_space():
        return None, "⚠️ Low disk space (<1GB).", vram_status_text

    CHUNK_SEC = 30
    total_duration = max(30, min(int(total_duration), 120))
    num_chunks = math.ceil(total_duration / CHUNK_SEC)

    PROCESS_SR = 48000
    OVERLAP_SEC = 0.20
    seed = random.randint(0, 2**31 - 1)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    musicgen_model.set_generation_params(
        duration=CHUNK_SEC,
        use_sampling=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        cfg_coef=float(cfg_scale),
        two_step_cfg=False,
    )

    vram_status_text = f"Start VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
    segments: List[AudioSegment] = []
    start_time = time.time()

    for idx in range(num_chunks):
        chunk_idx = idx + 1
        dur = CHUNK_SEC if (idx < num_chunks - 1) else (total_duration - CHUNK_SEC * (num_chunks - 1) or CHUNK_SEC)
        logger.info(f"Generating chunk {chunk_idx}/{num_chunks} ({dur}s)")

        try:
            with torch.no_grad():
                with autocast(dtype=torch.float16):
                    clean_memory()
                    if idx == 0:
                        audio = musicgen_model.generate([instrumental_prompt], progress=True)[0].cpu()
                    else:
                        # Condition the next chunk on the gated, balanced tail of the previous one.
                        prev_seg = segments[-1]
                        prev_seg = apply_noise_gate(prev_seg, threshold_db=-80, sample_rate=PROCESS_SR)
                        prev_seg = balance_stereo(prev_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                        tmp_prev = f"prev_{int(time.time()*1000)}.wav"
                        try:
                            prev_seg.export(tmp_prev, format="wav")
                            prev_audio, prev_sr = torchaudio.load(tmp_prev)
                            if prev_sr != PROCESS_SR:
                                prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, PROCESS_SR, lowpass_filter_width=64)
                            if prev_audio.shape[0] != 2:
                                prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]]
                            prev_audio = prev_audio.to(DEVICE)
                            tail = prev_audio[:, -int(PROCESS_SR * OVERLAP_SEC):]
                            audio = musicgen_model.generate_continuation(
                                prompt=tail,
                                prompt_sample_rate=PROCESS_SR,
                                descriptions=[instrumental_prompt],
                                progress=True
                            )[0].cpu()
                            del prev_audio, tail
                        finally:
                            try:
                                if os.path.exists(tmp_prev):
                                    os.remove(tmp_prev)
                            except OSError:
                                pass
                    clean_memory()
        except Exception as e:
            logger.error(f"Chunk {chunk_idx} generation failed: {e}")
            logger.error(traceback.format_exc())
            return None, f"❌ Generate failed at chunk {chunk_idx}.", vram_status_text

        try:
            if audio.shape[0] != 2:
                audio = audio.repeat(2, 1)[:, :audio.shape[1]]
            audio = audio.to(dtype=torch.float32)
            # MusicGen decodes at 32 kHz; upsample to the 48 kHz processing rate.
            audio = torchaudio.functional.resample(audio, 32000, PROCESS_SR, lowpass_filter_width=64)
            seg = _export_torch_to_segment(audio, PROCESS_SR, bd)
            if seg is None:
                return None, f"❌ Convert failed chunk {chunk_idx}.", vram_status_text
            seg = ensure_stereo(seg, PROCESS_SR, sample_width)
            seg = seg - 15
            seg = apply_noise_gate(seg, threshold_db=-80, sample_rate=PROCESS_SR)
            seg = balance_stereo(seg, noise_threshold=-40, sample_rate=PROCESS_SR)
            seg = rms_normalize(seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
            seg = apply_eq(seg, sample_rate=PROCESS_SR)
            seg = seg[:dur * 1000]
            segments.append(seg)
            del audio
            clean_memory()
            vram_status_text = f"VRAM after chunk {chunk_idx}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        except Exception as e:
            logger.error(f"Post-process failed chunk {chunk_idx}: {e}")
            logger.error(traceback.format_exc())
            return None, f"❌ Post-process failed chunk {chunk_idx}.", vram_status_text

    if not segments:
        return None, "❌ No audio generated.", vram_status_text

    logger.info("Combining chunks...")
    final_seg = segments[0]
    overlap_ms = int(OVERLAP_SEC * 1000)
    for i in range(1, len(segments)):
        final_seg = _crossfade(final_seg, segments[i], overlap_ms, PROCESS_SR, bd)

    final_seg = final_seg[:total_duration * 1000]
    final_seg = apply_noise_gate(final_seg, threshold_db=-80, sample_rate=PROCESS_SR)
    final_seg = balance_stereo(final_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
    final_seg = rms_normalize(final_seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
    final_seg = apply_eq(final_seg, sample_rate=PROCESS_SR)
    final_seg = apply_fade(final_seg, 500, 800)
    final_seg = final_seg - 10
    final_seg = final_seg.set_frame_rate(out_sr)

    fname = f"ghostai_{int(time.time())}.mp3"
    mp3_path = str(MP3_DIR / fname)
    try:
        clean_memory()
        final_seg.export(mp3_path, format="mp3", bitrate=bitrate,
                         tags={"title": "GhostAI Instrumental", "artist": "GhostAI"})
    except Exception as e:
        logger.error(f"MP3 export failed: {e}")
        # Fall back to a lower bitrate if the first export fails.
        fb = str(MP3_DIR / f"ghostai_fb_{int(time.time())}.mp3")
        try:
            final_seg.export(fb, format="mp3", bitrate="128k")
            mp3_path = fb
        except Exception as ee:
            return None, f"❌ Export failed: {ee}", vram_status_text

    elapsed = time.time() - start_time
    vram_status_text = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
    logger.info(f"Done in {elapsed:.2f}s -> {mp3_path}")
    return mp3_path, "✅ Generated.", vram_status_text


def generate_music_wrapper(*args):
    try:
        return generate_music(*args)
    finally:
        clean_memory()


def clear_inputs():
    s = DEFAULT_SETTINGS.copy()
    return (
        s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"],
        s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"],
        s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"],
        s["bitrate"], s["output_sample_rate"], s["bit_depth"]
    )


BUSY_LOCK = threading.Lock()
BUSY_FLAG = False
BUSY_FILE = "/tmp/musicgen_busy.lock"
CURRENT_JOB: Dict[str, Any] = {"id": None, "start": None}


def set_busy(val: bool, job_id: Optional[str] = None):
    global BUSY_FLAG, CURRENT_JOB
    with BUSY_LOCK:
        BUSY_FLAG = val
        if val:
            CURRENT_JOB["id"] = job_id or f"job_{int(time.time())}"
            CURRENT_JOB["start"] = time.time()
            try:
                Path(BUSY_FILE).write_text(CURRENT_JOB["id"])
            except Exception:
                pass
        else:
            CURRENT_JOB["id"] = None
            CURRENT_JOB["start"] = None
            try:
                if os.path.exists(BUSY_FILE):
                    os.remove(BUSY_FILE)
            except Exception:
                pass


def is_busy() -> bool:
    with BUSY_LOCK:
        return BUSY_FLAG


def job_elapsed() -> float:
    with BUSY_LOCK:
        if CURRENT_JOB["start"] is None:
            return 0.0
        return time.time() - CURRENT_JOB["start"]
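
# External monitors can either poll GET /status below or stat the BUSY_FILE
# lock written by set_busy(); both reflect the same in-process flag.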


class RenderRequest(BaseModel):
    instrumental_prompt: str
    cfg_scale: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    total_duration: Optional[int] = None
    bpm: Optional[int] = None
    drum_beat: Optional[str] = None
    synthesizer: Optional[str] = None
    rhythmic_steps: Optional[str] = None
    bass_style: Optional[str] = None
    guitar_style: Optional[str] = None
    target_volume: Optional[float] = None
    preset: Optional[str] = None
    max_steps: Optional[int] = None
    bitrate: Optional[str] = None
    output_sample_rate: Optional[str] = None
    bit_depth: Optional[str] = None


fastapp = FastAPI(title=f"GhostAI Music Server {RELEASE}", version=RELEASE)
fastapp.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)


@fastapp.get("/health")
def health():
    return {"ok": True, "ts": int(time.time()), "release": RELEASE}


@fastapp.get("/status")
def status():
    return {"busy": is_busy(), "job_id": CURRENT_JOB["id"], "since": CURRENT_JOB["start"], "elapsed": job_elapsed()}


@fastapp.get("/styles")
def styles():
    return {"styles": STYLES.list_styles()}


@fastapp.get("/prompt/{style}")
def prompt(style: str, bpm: int = 120, chunk: int = 1,
           drum_beat: str = "none", synthesizer: str = "none", rhythmic_steps: str = "none",
           bass_style: str = "none", guitar_style: str = "none"):
    txt = STYLES.build_prompt(style, bpm, chunk, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style)
    if not txt:
        raise HTTPException(status_code=404, detail="Style not found")
    return {"style": style, "prompt": txt}


# Register one GET route per style that declares an api_name in prompts.ini.
# The factory function pins the style name per route (a bare closure in the
# loop body would capture the loop variable instead). Starlette requires
# route paths to start with "/", so normalize api_name defensively.
for sec, cfg in list(STYLES.styles.items()):
    api_name = cfg.get("api_name")
    if api_name:
        route_path = api_name if api_name.startswith("/") else f"/{api_name}"

        def make_route(sname, route_path_):
            @fastapp.get(route_path_)
            def _(bpm: int = 120, chunk: int = 1,
                  drum_beat: str = "none", synthesizer: str = "none", rhythmic_steps: str = "none",
                  bass_style: str = "none", guitar_style: str = "none"):
                txt = STYLES.build_prompt(sname, bpm, chunk, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style)
                if not txt:
                    raise HTTPException(status_code=404, detail="Style not found")
                return {"style": sname, "prompt": txt}

        make_route(sec, route_path)


@fastapp.get("/config")
def get_config():
    return {"defaults": CURRENT_SETTINGS, "release": RELEASE}


@fastapp.post("/settings")
def set_settings(payload: Dict[str, Any]):
    try:
        s = CURRENT_SETTINGS.copy()
        s.update(payload or {})
        save_settings(s)
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return {"ok": True, "saved": s}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@fastapp.post("/render")
def render(req: RenderRequest):
    if is_busy():
        raise HTTPException(status_code=409, detail="Server busy")
    job_id = f"render_{int(time.time())}"
    set_busy(True, job_id)
    try:
        s = CURRENT_SETTINGS.copy()
        for k, v in req.dict().items():
            if v is not None:
                s[k] = v
        mp3, msg, vram = generate_music(
            s.get("instrumental_prompt", req.instrumental_prompt),
            float(s.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])),
            int(s.get("top_k", DEFAULT_SETTINGS["top_k"])),
            float(s.get("top_p", DEFAULT_SETTINGS["top_p"])),
            float(s.get("temperature", DEFAULT_SETTINGS["temperature"])),
            int(s.get("total_duration", DEFAULT_SETTINGS["total_duration"])),
            int(s.get("bpm", DEFAULT_SETTINGS["bpm"])),
            str(s.get("drum_beat", DEFAULT_SETTINGS["drum_beat"])),
            str(s.get("synthesizer", DEFAULT_SETTINGS["synthesizer"])),
            str(s.get("rhythmic_steps", DEFAULT_SETTINGS["rhythmic_steps"])),
            str(s.get("bass_style", DEFAULT_SETTINGS["bass_style"])),
            str(s.get("guitar_style", DEFAULT_SETTINGS["guitar_style"])),
            float(s.get("target_volume", DEFAULT_SETTINGS["target_volume"])),
            str(s.get("preset", DEFAULT_SETTINGS["preset"])),
            str(s.get("max_steps", DEFAULT_SETTINGS["max_steps"])),
            "",
            str(s.get("bitrate", DEFAULT_SETTINGS["bitrate"])),
            str(s.get("output_sample_rate", DEFAULT_SETTINGS["output_sample_rate"])),
            str(s.get("bit_depth", DEFAULT_SETTINGS["bit_depth"]))
        )
        if not mp3:
            raise HTTPException(status_code=500, detail=msg)
        return {"ok": True, "job_id": job_id, "path": mp3, "status": msg, "vram": vram, "release": RELEASE}
    finally:
        set_busy(False, None)


def _start_fastapi():
    uvicorn.run(fastapp, host="0.0.0.0", port=8555, log_level="info")


api_thread = threading.Thread(target=_start_fastapi, daemon=True)
api_thread.start()
logger.info(f"FastAPI server started on http://0.0.0.0:8555 [{RELEASE}]")
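
# Example client calls against the API above (illustrative; the /metallica
# style route assumes a [metallica] section exists in prompts.ini):
#   curl http://localhost:8555/health
#   curl http://localhost:8555/status
#   curl "http://localhost:8555/prompt/metallica?bpm=140&chunk=1"
#   curl -X POST http://localhost:8555/render \
#        -H "Content-Type: application/json" \
#        -d '{"instrumental_prompt": "driving instrumental at 140 BPM", "total_duration": 60}'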


def read_css() -> str:
    try:
        if CSS_FILE.exists():
            return CSS_FILE.read_text(encoding="utf-8")
        return """
:root { color-scheme: dark; }
body, .gradio-container { background: #0E1014 !important; color: #FFFFFF !important; }
* { color: #FFFFFF !important; }
input, textarea, select {
  background: #151922 !important; color: #FFFFFF !important;
  border: 1px solid #2A3142 !important; border-radius: 10px !important;
}
.ga-header { display:flex; gap:12px; align-items:center; }
.ga-header .logo { font-size: 28px; }
"""
    except Exception as e:
        logger.error(f"Failed to read CSS: {e}")
        return ""


def read_examples() -> str:
    try:
        return EXAMPLES_MD.read_text(encoding="utf-8")
    except Exception:
        return "# GhostAI Examples\n\n_Provide examples.md next to app.py_"


loaded = CURRENT_SETTINGS

with gr.Blocks(css=read_css(), analytics_enabled=False, title=f"GhostAI Music Generator {RELEASE}") as demo:
    with gr.Tabs():
        with gr.Tab(f"🎙️ Generator · {RELEASE}"):
            gr.Markdown(f"""
<div class="ga-header" role="banner" aria-label="GhostAI Music Generator">
  <div class="logo">👻</div>
  <h1>GhostAI Music Generator</h1>
  <p>Unified 30s chunking · 60–120s ready · API & status</p>
</div>
""")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Prompt")
                instrumental_prompt = gr.Textbox(
                    label="Instrumental Prompt",
                    placeholder="Type a prompt or click a style button below",
                    lines=4,
                    value=loaded.get("instrumental_prompt", "")
                )

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Band / Style (grid 4 per row)")

                def row_of_buttons(entries):
                    with gr.Row(equal_height=True):
                        buttons = []
                        for key, label in entries:
                            btn = gr.Button(label, variant="secondary", scale=1, min_width=0)
                            buttons.append((key, btn))
                    return buttons

                row1 = row_of_buttons([
                    ("metallica", "Metallica (Thrash) 🎸"),
                    ("nirvana", "Nirvana (Grunge) 🎤"),
                    ("pearl_jam", "Pearl Jam (Grunge) 🦪"),
                    ("soundgarden", "Soundgarden (Grunge/Alt Metal) 🌞"),
                ])
                row2 = row_of_buttons([
                    ("foo_fighters", "Foo Fighters (Alt Rock) 🎤"),
                    ("red_hot_chili_peppers", "Red Hot Chili Peppers (Funk Rock) 🌶️"),
                    ("smashing_pumpkins", "Smashing Pumpkins (Alt) 🎃"),
                    ("radiohead", "Radiohead (Experimental) 🧠"),
                ])
                row3 = row_of_buttons([
                    ("alternative_rock", "Alternative Rock (Pixies) 🎵"),
                    ("post_punk", "Post-Punk (Joy Division) 🎤"),
                    ("indie_rock", "Indie Rock (Arctic Monkeys) 🎤"),
                    ("funk_rock", "Funk Rock (RATM) 🎺"),
                ])
                row4 = row_of_buttons([
                    ("detroit_techno", "Detroit Techno 🎛️"),
                    ("deep_house", "Deep House 🏠"),
                    ("classical_star_wars", "Classical (Star Wars Suite) ✨"),
                    ("foo_pad", "·")
                ])

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Settings")
                with gr.Group():
                    with gr.Row():
                        cfg_scale = gr.Slider(1.0, 10.0, step=0.1, value=float(loaded.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), label="CFG Scale")
                        top_k = gr.Slider(10, 500, step=10, value=int(loaded.get("top_k", DEFAULT_SETTINGS["top_k"])), label="Top-K")
                        top_p = gr.Slider(0.0, 1.0, step=0.01, value=float(loaded.get("top_p", DEFAULT_SETTINGS["top_p"])), label="Top-P")
                        temperature = gr.Slider(0.1, 2.0, step=0.01, value=float(loaded.get("temperature", DEFAULT_SETTINGS["temperature"])), label="Temperature")
                    with gr.Row():
                        total_duration = gr.Dropdown(choices=[30, 60, 90, 120], value=int(loaded.get("total_duration", 60)), label="Song Length (seconds)")
                        bpm = gr.Slider(60, 180, step=1, value=int(loaded.get("bpm", 120)), label="Tempo (BPM)")
                        target_volume = gr.Slider(-30.0, -20.0, step=0.5, value=float(loaded.get("target_volume", -23.0)), label="Target Loudness (dBFS RMS)")
                        preset = gr.Dropdown(choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], value=str(loaded.get("preset", "default")), label="Preset")
                    with gr.Row():
                        drum_beat = gr.Dropdown(choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing", "four-on-the-floor", "steady kick", "orchestral percussion", "precise drums", "heavy drums"], value=str(loaded.get("drum_beat", "none")), label="Drum Beat")
                        synthesizer = gr.Dropdown(choices=["none", "analog synth", "digital pad", "arpeggiated synth", "lush synths", "atmospheric synths", "pulsing synths", "analog pad", "warm synths"], value=str(loaded.get("synthesizer", "none")), label="Synthesizer")
                        rhythmic_steps = gr.Dropdown(choices=["none", "syncopated steps", "steady steps", "complex steps", "martial march", "staccato ostinato", "triplet swells"], value=str(loaded.get("rhythmic_steps", "none")), label="Rhythmic Steps")
                    with gr.Row():
                        bass_style = gr.Dropdown(choices=["none", "slap bass", "deep bass", "melodic bass", "groovy bass", "hypnotic bass", "driving bass", "low brass", "cellos", "double basses", "subby bass"], value=str(loaded.get("bass_style", "none")), label="Bass Style")
                        guitar_style = gr.Dropdown(choices=["none", "distorted", "clean", "jangle", "downpicked", "thrash riffing", "dreamy", "experimental", "funky"], value=str(loaded.get("guitar_style", "none")), label="Guitar Style")
                        max_steps = gr.Dropdown(choices=[1000, 1200, 1300, 1500], value=int(loaded.get("max_steps", 1500)), label="Max Steps (hint)")

                bitrate_state = gr.State(value=str(loaded.get("bitrate", "192k")))
                sample_rate_state = gr.State(value=str(loaded.get("output_sample_rate", "48000")))
                bit_depth_state = gr.State(value=str(loaded.get("bit_depth", "16")))

                with gr.Row():
                    bitrate_128_btn = gr.Button("Bitrate 128k", variant="secondary")
                    bitrate_192_btn = gr.Button("Bitrate 192k", variant="secondary")
                    bitrate_320_btn = gr.Button("Bitrate 320k", variant="secondary")
                    sample_rate_22050_btn = gr.Button("SR 22.05k", variant="secondary")
                    sample_rate_44100_btn = gr.Button("SR 44.1k", variant="secondary")
                    sample_rate_48000_btn = gr.Button("SR 48k", variant="secondary")
                    bit_depth_16_btn = gr.Button("16-bit", variant="secondary")
                    bit_depth_24_btn = gr.Button("24-bit", variant="secondary")

                with gr.Row():
                    gen_btn = gr.Button("Generate 🎶", variant="primary")
                    clr_btn = gr.Button("Clear 🧹", variant="secondary")
                    save_btn = gr.Button("Save Settings 💾", variant="secondary")
                    load_btn = gr.Button("Load Settings 📂", variant="secondary")
                    reset_btn = gr.Button("Reset Defaults ♻️", variant="secondary")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Output")
                out_audio = gr.Audio(label="Generated Track", type="filepath")
                status_box = gr.Textbox(label="Status", interactive=False)
                vram_box = gr.Textbox(label="VRAM", interactive=False, value="")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Logs")
                log_output = gr.Textbox(label="Current Log (rotating ≤ 5MB)", lines=14, interactive=False)
                log_btn = gr.Button("View Log 📜", variant="secondary")

        with gr.Tab("📖 Info & Examples"):
            md_box = gr.Markdown(read_examples())
            refresh_md = gr.Button("Refresh Examples.md", variant="secondary")
            refresh_md.click(lambda: read_examples(), outputs=md_box)

    def set_prompt_and_settings_from_style(style_key, current_bpm, current_drum, current_synth, current_steps, current_bass, current_guitar):
        defaults = STYLES.style_defaults_for_ui(style_key)

        new_bpm = int(defaults.get("bpm", current_bpm or 120))
        new_drum = str(defaults.get("drum_beat", "none"))
        new_synth = str(defaults.get("synthesizer", "none"))
        new_steps = str(defaults.get("rhythmic_steps", "none"))
        new_bass = str(defaults.get("bass_style", "none"))
        new_guitar = str(defaults.get("guitar_style", "none"))

        prompt_txt = STYLES.build_prompt(
            style_key,
            new_bpm,
            1,
            new_drum,
            new_synth,
            new_steps,
            new_bass,
            new_guitar
        )
        if not prompt_txt:
            prompt_txt = f"{style_key}: update prompts.ini"

        return (
            prompt_txt,
            new_bpm,
            new_drum,
            new_synth,
            new_steps,
            new_bass,
            new_guitar
        )

    for key, btn in row1 + row2 + row3 + row4:
        if key == "foo_pad":
            continue
        btn.click(
            set_prompt_and_settings_from_style,
            inputs=[gr.State(key), bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style],
            outputs=[instrumental_prompt, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]
        )

    bitrate_128_btn.click(lambda: "128k", outputs=bitrate_state)
    bitrate_192_btn.click(lambda: "192k", outputs=bitrate_state)
    bitrate_320_btn.click(lambda: "320k", outputs=bitrate_state)
    sample_rate_22050_btn.click(lambda: "22050", outputs=sample_rate_state)
    sample_rate_44100_btn.click(lambda: "44100", outputs=sample_rate_state)
    sample_rate_48000_btn.click(lambda: "48000", outputs=sample_rate_state)
    bit_depth_16_btn.click(lambda: "16", outputs=bit_depth_state)
    bit_depth_24_btn.click(lambda: "24", outputs=bit_depth_state)

    gen_btn.click(
        generate_music_wrapper,
        inputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, vram_box, bitrate_state, sample_rate_state, bit_depth_state
        ],
        outputs=[out_audio, status_box, vram_box]
    )

    clr_btn.click(
        clear_inputs, outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
        ]
    )

    def _save_action(
        instrumental_prompt_v, cfg_v, top_k_v, top_p_v, temp_v, dur_v, bpm_v,
        drum_v, synth_v, steps_v, bass_v, guitar_v, vol_v, preset_v, maxsteps_v, br_v, sr_v, bd_v
    ):
        s = {
            "instrumental_prompt": instrumental_prompt_v,
            "cfg_scale": float(cfg_v),
            "top_k": int(top_k_v),
            "top_p": float(top_p_v),
            "temperature": float(temp_v),
            "total_duration": int(dur_v),
            "bpm": int(bpm_v),
            "drum_beat": str(drum_v),
            "synthesizer": str(synth_v),
            "rhythmic_steps": str(steps_v),
            "bass_style": str(bass_v),
            "guitar_style": str(guitar_v),
            "target_volume": float(vol_v),
            "preset": str(preset_v),
            "max_steps": int(maxsteps_v),
            "bitrate": str(br_v),
            "output_sample_rate": str(sr_v),
            "bit_depth": str(bd_v)
        }
        save_settings(s)
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return "✅ Settings saved."

    def _load_action():
        s = load_settings()
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return (
            s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"],
            s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"],
            s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"],
            s["bitrate"], s["output_sample_rate"], s["bit_depth"],
            "✅ Settings loaded."
        )

    def _reset_action():
        s = DEFAULT_SETTINGS.copy()
        save_settings(s)
        for k, v in s.items():
            CURRENT_SETTINGS[k] = v
        return (
            s["instrumental_prompt"], s["cfg_scale"], s["top_k"], s["top_p"], s["temperature"],
            s["total_duration"], s["bpm"], s["drum_beat"], s["synthesizer"], s["rhythmic_steps"],
            s["bass_style"], s["guitar_style"], s["target_volume"], s["preset"], s["max_steps"],
            s["bitrate"], s["output_sample_rate"], s["bit_depth"],
            "✅ Defaults restored."
        )

    save_btn.click(
        _save_action,
        inputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state
        ],
        outputs=status_box
    )

    load_btn.click(
        _load_action,
        outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
        ]
    )

    reset_btn.click(
        _reset_action,
        outputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, status_box
        ]
    )

    def _get_log():
        try:
            return LOG_FILE.read_text(encoding="utf-8")[-40000:]
        except Exception as e:
            return f"Log read error: {e}"

    log_btn.click(_get_log, outputs=log_output)


if __name__ == "__main__":
    print(f"{Fore.CYAN}Launching Gradio UI http://0.0.0.0:9999 [{RELEASE}]{Fore.RESET}")
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=9999,
            share=False,
            inbrowser=False,
            show_error=True
        )
    except Exception as e:
        logger.error(f"Gradio launch failed: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)