|
|
import os |
|
|
import torch |
|
|
import torchaudio |
|
|
import psutil |
|
|
import time |
|
|
import sys |
|
|
import numpy as np |
|
|
import gc |
|
|
import gradio as gr |
|
|
from pydub import AudioSegment |
|
|
import soundfile as sf |
|
|
import pyloudnorm as pyln |
|
|
from audiocraft.models import MusicGen |
|
|
from torch.amp import autocast |
|
|
import json |
|
|
import configparser |
|
|
import random |
|
|
import string |
|
|
import uvicorn |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.responses import FileResponse |
|
|
from pydantic import BaseModel |
|
|
import multiprocessing |
|
|
import re |
|
|
import datetime |
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning)

# Worker processes (if any) are started with "spawn": PyTorch documents that the
# CUDA runtime does not support the "fork" start method.
multiprocessing.set_start_method('spawn', force=True)

# Environment knobs for PyTorch / CUDA, applied before this module performs any
# CUDA work (the first CUDA query happens further down).
_TORCH_ENV_DEFAULTS = {
    "TORCH_NN_UTILS_LOG_LEVEL": "0",
    # NOTE(review): PyTorch documents CUDA_LAUNCH_BLOCKING=1 as a debugging aid
    # (every kernel launch becomes synchronous); it can slow generation down.
    "CUDA_LAUNCH_BLOCKING": "1",
    "CUDA_MODULE_LOADING": "LAZY",
    "TORCH_USE_CUDA_DSA": "1",
    "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True",
    # Presumably only consulted when building CUDA extensions — TODO confirm it is needed.
    "TORCH_CUDA_ARCH_LIST": "7.5;8.0;8.6;8.9",
}
os.environ.update(_TORCH_ENV_DEFAULTS)

# Best effort: allow TF32 matmuls and cuDNN autotuning; any failure is ignored.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
except Exception:
    pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_version_triplet(s: str): |
|
|
m = re.findall(r"\d+", s) |
|
|
m = [int(x) for x in m[:3]] |
|
|
while len(m) < 3: |
|
|
m.append(0) |
|
|
return tuple(m) |
|
|
|
|
|
def _abort(message):
    """Print *message* and terminate the process with exit status 1."""
    print(message)
    sys.exit(1)


# Refuse to run on PyTorch older than 2.0.
if _parse_version_triplet(torch.__version__) < (2, 0, 0):
    _abort(f"ERROR: PyTorch {torch.__version__} incompatible. Need >=2.0.0.")

# A CUDA device is mandatory; there is no CPU fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    _abort("ERROR: CUDA required. CPU disabled.")

# Require compute capability 7.0 or newer on GPU 0.
cc_major, cc_minor = torch.cuda.get_device_capability(0)
if cc_major < 7:
    _abort(f"ERROR: GPU Compute Capability {torch.cuda.get_device_capability(0)} unsupported. Need >=7.0.")

gpu_name = torch.cuda.get_device_name(0)
print(f"Using GPU: {gpu_name} (CUDA {torch.version.cuda}, Compute Capability {(cc_major, cc_minor)})")

# Autocast dtype: bfloat16 only when the device reports bf16 support AND is
# capability >= 8; otherwise float16. A failing probe counts as "no bf16".
try:
    bf16_supported = torch.cuda.is_bf16_supported()
except Exception:
    bf16_supported = False
AUTOCAST_DTYPE = torch.bfloat16 if (cc_major >= 8 and bf16_supported) else torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_resource_usage(stage: str):
    """Print a short GPU/CPU/RAM usage report labelled with *stage*.

    GPU figures are the PyTorch caching allocator's allocated/reserved bytes in
    GiB; if querying them fails, both are reported as 0.00. CPU and RAM
    percentages come from psutil.
    """
    gib = 1024 ** 3
    try:
        gpu_usage = (torch.cuda.memory_allocated() / gib, torch.cuda.memory_reserved() / gib)
    except Exception:
        gpu_usage = (0.0, 0.0)
    allocated_gb, reserved_gb = gpu_usage
    print(f"--- {stage} ---")
    print(f"GPU Memory: {allocated_gb:.2f} GB allocated, {reserved_gb:.2f} GB reserved")
    print(f"CPU: {psutil.cpu_percent()}% | Memory: {psutil.virtual_memory().percent}%")
    print("---------------")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# All rendered MP3s and the JSON song index live under ./mp3 (created on import).
output_dir = "mp3"
metadata_file = os.path.join(output_dir, "songs_metadata.json")
# "idle" or "rendering"; updated by generate_music.
api_status = "idle"
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Vocabulary for the "{placeholder}" fields of the genre prompt templates: each
# key is a placeholder name, each value the list of choices one is drawn from.
prompt_variables = dict(
    style=[
        'epic', 'gritty', 'smooth', 'lush', 'raw', 'intimate', 'driving', 'moody',
        'psychedelic', 'uplifting', 'melancholic', 'aggressive', 'dreamy', 'retro',
        'futuristic', 'energetic', 'brooding', 'euphoric', 'jazzy', 'cinematic',
        'somber', 'triumphant', 'mystical', 'grunge', 'ethereal',
    ],
    key=['C major', 'D major', 'E minor', 'F minor', 'G major', 'A minor',
         'B-flat major', 'G minor', 'D minor', 'F major'],
    bpm=[80, 90, 100, 110, 120, 124, 128, 130, 140, 150, 160, 170, 180],
    time_signature=['4/4', '3/4', '6/8'],
    guitar_style=[
        'raw distorted', 'melodic', 'fuzzy', 'crisp', 'jangly', 'clean', 'twangy',
        'shimmering', 'grunge', 'bluesy', 'slide', 'wah-infused', 'chunky',
    ],
    bass_style=[
        'punchy', 'deep', 'groovy', 'melodic', 'throbbing', 'slappy', 'funky',
        'walking', 'booming', 'resonant', 'subtle',
    ],
    drum_style=[
        'dynamic', 'minimal', 'hard-hitting', 'swinging', 'polyrhythmic', 'brushed',
        'tight', 'loose', 'electronic', 'acoustic', 'retro', 'punchy',
    ],
    drum_feature=['heavy snare', 'crisp cymbals', 'tight kicks', 'syncopated hits',
                  'rolling toms', 'ghost notes', 'blast beats'],
    organ_style=['subtle Hammond', 'swirling', 'warm Leslie', 'church', 'gritty',
                 'vintage', 'moody'],
    synth_style=['atmospheric', 'bright', 'eerie', 'soaring', 'chopped', 'arpeggiated',
                 'pulsing', 'glitchy', 'analog', 'digital', 'layered'],
    vocal_style=['chopped', 'soulful', 'haunting', 'melodic', 'harmonized', 'layered',
                 'ethereal', 'gruff', 'breathy'],
    hihat_style=['crisp', 'swinging', 'rapid', 'shuffling', 'open', 'tight', 'stuttered'],
    pad_style=['evolving', 'ambient', 'lush', 'dark', 'shimmering', 'warm', 'icy'],
    kick_style=['deep', 'four-on-the-floor', 'subtle', 'punchy', 'booming', 'clicky'],
    lead_style=['fluid', 'intricate', 'soaring', 'expressive', 'virtuosic', 'minimalist',
                'bluesy', 'lyrical'],
    lead_instrument=['saxophone', 'trumpet', 'guitar', 'flute', 'violin', 'clarinet',
                     'trombone'],
    piano_style=['expressive Rhodes', 'rapid', 'smooth', 'dramatic', 'stride', 'ambient',
                 'classical', 'jazzy', 'sparse'],
    keyboard_style=['ornate', 'delicate', 'virtuosic', 'minimal', 'retro', 'spacey'],
    string_style=['sweeping', 'delicate', 'dramatic', 'lush', 'pizzicato', 'staccato',
                  'sustained'],
    brass_style=['bold', 'heroic', 'muted', 'fanfare', 'jazzy', 'smooth'],
    woodwind_style=['subtle', 'fluttering', 'melodic', 'airy', 'reedy', 'expressive'],
    flute_style=['fluttering', 'ornate', 'airy', 'breathy', 'trilling'],
    horn_style=['heroic', 'bold', 'soaring', 'mellow', 'stinging'],
    choir_style=['mystical', 'ethereal', 'dramatic', 'angelic', 'epic', 'somber'],
    sample_style=['jazzy', 'soulful', 'gritty', 'cinematic', 'vinyl', 'lo-fi', 'retro'],
    scratch_style=['crackling vinyl', 'sharp', 'rhythmic', 'chopped', 'transform'],
    snare_style=['crisp', 'booming', 'tight', 'snappy', 'rimshot', 'layered'],
    breakdown_style=['euphoric', 'stripped-down', 'intense', 'ambient', 'glitchy',
                     'dramatic'],
    # Section lengths (in bars) available to templates.
    intro_bars=[4, 8, 16],
    verse_bars=[8, 16, 32],
    chorus_bars=[8, 16],
    bridge_bars=[4, 8, 16],
    outro_bars=[8, 16],
    build_bars=[8, 16, 32],
    drop_bars=[16, 32],
    main_bars=[16, 32],
    breakdown_bars=[8, 16],
    head_bars=[16, 32],
    solo_bars=[8, 16, 32],
    fugue_bars=[16, 32],
    coda_bars=[8, 16],
    theme_bars=[16, 32],
    development_bars=[16, 32],
    climax_bars=[8, 16],
    groove_bars=[16, 32],
    vibe=[
        'raw', 'energetic', 'melancholic', 'hypnotic', 'soulful', 'intimate',
        'virtuosic', 'elegant', 'cinematic', 'gritty', 'nostalgic', 'dark',
        'uplifting', 'bittersweet', 'heroic', 'dreamy', 'aggressive', 'relaxed',
        'futuristic', 'retro', 'mystical', 'triumphant',
    ],
    production_style=[
        'lo-fi', 'warm analog', 'clean digital', 'lush', 'crisp acoustic',
        'polished pop', 'grand orchestral', 'grunge', 'minimalist', 'industrial',
        'vintage',
    ],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_default_genre_prompts_ini(ini_path):
    """Write the built-in genre prompt templates and band names to *ini_path*.

    The file gets two sections: [Prompts] (genre -> str.format template whose
    placeholders are keys of ``prompt_variables``) and [BandNames] (genre ->
    comma-separated band names). An existing file is overwritten.
    """
    prompts = {
        'nirvana': '{style} grunge with {guitar_style} guitar, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM',
        'classic_rock': '{style} classic rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM',
        'detroit_techno': '{style} techno with {synth_style} synths, {kick_style} kick, {hihat_style} hi-hats, {vibe} vibe at {bpm} BPM',
        'smooth_jazz': '{style} jazz with {piano_style} piano, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM',
        'alternative_rock': '{style} alternative rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM',
        'deep_house': '{style} deep house with {synth_style} synths, {kick_style} kick, {vibe} vibe at {bpm} BPM',
        'bebop_jazz': '{style} bebop jazz with {piano_style} piano, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM',
        'baroque_classical': '{style} baroque classical with {string_style} strings, {keyboard_style} harpsichord in {key} at {bpm} BPM',
        'romantic_classical': '{style} romantic classical with {string_style} strings, {piano_style} piano in {key} at {bpm} BPM',
        'boom_bap_hiphop': '{style} boom bap hip-hop with {sample_style} samples, {drum_style} drums, {scratch_style} scratches at {bpm} BPM',
        'trap_hiphop': '{style} trap hip-hop with {synth_style} synths, {kick_style} kick, {snare_style} snare at {bpm} BPM',
        'pop_rock': '{style} pop rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM',
        'fusion_jazz': '{style} fusion jazz with {piano_style} piano, {guitar_style} guitar, {drum_style} drums in {key} at {bpm} BPM',
        'edm': '{style} EDM with {synth_style} synths, {kick_style} kick, {vibe} vibe at {bpm} BPM',
        'indie_folk': '{style} indie folk with {guitar_style} guitar, {vocal_style} vocals, {drum_style} drums in {key} at {bpm} BPM',
        'star_wars': '{style} epic orchestral with {brass_style} brass, {string_style} strings, {vibe} vibe in {key} at {bpm} BPM',
        'star_wars_classical': '{style} classical orchestral with {string_style} strings, {horn_style} horns in {key} at {bpm} BPM',
        'wutang': '{style} hip-hop with {sample_style} samples, {drum_style} drums, {scratch_style} scratches at {bpm} BPM',
        'milesdavis': '{style} jazz with {lead_instrument} lead, {piano_style} piano, {bass_style} bass in {key} at {bpm} BPM',
    }
    bands = {
        'nirvana': 'Nirvana, Soundgarden',
        'classic_rock': 'Led Zeppelin, The Rolling Stones',
        'detroit_techno': 'Underground Resistance, Jeff Mills',
        'smooth_jazz': 'Pat Metheny, George Benson',
        'alternative_rock': 'Radiohead, Smashing Pumpkins',
        'deep_house': 'Moodymann, Theo Parrish',
        'bebop_jazz': 'Charlie Parker, Dizzy Gillespie',
        'baroque_classical': 'Bach, Vivaldi',
        'romantic_classical': 'Chopin, Liszt',
        'boom_bap_hiphop': 'A Tribe Called Quest, Pete Rock',
        'trap_hiphop': 'Future, Metro Boomin',
        'pop_rock': 'Coldplay, The Killers',
        'fusion_jazz': 'Weather Report, Herbie Hancock',
        'edm': 'Deadmau5, Skrillex',
        'indie_folk': 'Fleet Foxes, Bon Iver',
        'star_wars': 'John Williams',
        'star_wars_classical': 'John Williams',
        'wutang': 'Wu-Tang Clan',
        'milesdavis': 'Miles Davis',
    }

    parser = configparser.ConfigParser()
    parser.read_dict({'Prompts': prompts, 'BandNames': bands})
    with open(ini_path, 'w') as handle:
        parser.write(handle)
    print(f"Created default {ini_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The Gradio UI stylesheet is mandatory: exit if it is missing or unreadable.
css_path = "style.css"
if not os.path.exists(css_path):
    print(f"ERROR: {css_path} not found. Please create style.css with the required CSS content.")
    sys.exit(1)
try:
    with open(css_path, 'r') as css_handle:
        css = css_handle.read()
except Exception as e:
    print(f"ERROR: Failed to read {css_path}: {e}. Please ensure style.css exists and is readable.")
    sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load genre templates/band names. A missing file, a file lacking either required
# section, or any error while reading leads to (re)writing the default INI —
# NOTE(review): this overwrites a user-edited file that fails to parse.
config = configparser.ConfigParser()
ini_path = "genre_prompts.ini"
try:
    if not os.path.exists(ini_path):
        print(f"WARNING: {ini_path} not found. Creating default INI file.")
        create_default_genre_prompts_ini(ini_path)
    config.read(ini_path)
    if not {'Prompts', 'BandNames'} <= set(config.sections()):
        print(f"WARNING: Invalid {ini_path}. Creating default INI file.")
        create_default_genre_prompts_ini(ini_path)
        config.read(ini_path)
except Exception as e:
    print(f"ERROR: Failed to read {ini_path}: {e}. Creating default INI file.")
    create_default_genre_prompts_ini(ini_path)
    config.read(ini_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_musicgen_with_fallback():
    """Load the largest available MusicGen checkpoint, trying large, then medium, then small.

    Checkpoint locations come from MUSICGEN_MODEL_PATH_LARGE / _MEDIUM / _SMALL,
    defaulting to directories under /home/ubuntu/musicpack/models. Empty or
    non-existent paths are skipped; a checkpoint that raises while loading is
    logged and the next size is tried.

    Returns:
        (model, name): the loaded MusicGen instance and "large", "medium" or "small".

    Raises:
        SystemExit(1): no checkpoint could be loaded. This now also covers the
            case where none of the configured paths exists (previously the
            function fell through and returned None).
    """
    candidates = [
        ("large", os.getenv("MUSICGEN_MODEL_PATH_LARGE", "/home/ubuntu/musicpack/models/musicgen-large")),
        ("medium", os.getenv("MUSICGEN_MODEL_PATH_MEDIUM", "/home/ubuntu/musicpack/models/musicgen-medium")),
        ("small", os.getenv("MUSICGEN_MODEL_PATH_SMALL", "/home/ubuntu/musicpack/models/musicgen-small")),
    ]

    last_error = None
    for name, path in candidates:
        if not path:
            continue
        if not os.path.exists(path):
            print(f"NOTE: Model path not found: {path} (skipping {name})")
            continue
        try:
            print(f"Loading MusicGen {name} model from {path} ...")
            # Free whatever a previous failed attempt left on the GPU.
            torch.cuda.empty_cache()
            gc.collect()
            # NOTE(review): autocast only affects ops executed inside it; it is
            # unclear that loading needs it — kept to preserve behavior.
            with autocast('cuda', dtype=AUTOCAST_DTYPE):
                mdl = MusicGen.get_pretrained(path, device=device)
        except Exception as e:  # RuntimeError (e.g. OOM) and anything else alike
            last_error = e
            print(f"WARNING: Failed to load {name} model due to: {e}")
            torch.cuda.empty_cache()
            gc.collect()
            continue
        print(f"Loaded MusicGen {name}. Sample rate: {mdl.sample_rate}Hz")
        return mdl, name

    if last_error is not None:
        print(f"ERROR: All model loads failed. Last error: {last_error}")
    else:
        print("ERROR: No MusicGen model found at any configured path.")
    raise SystemExit(1)
|
|
|
|
|
# Generation defaults applied right after loading; generate_chunk_oom_safe
# overrides them per request.
_DEFAULT_GENERATION_PARAMS = dict(
    duration=10,
    use_sampling=True,
    top_k=50,
    top_p=0.0,
    temperature=0.8,
    cfg_coef=3.0,
    two_step_cfg=False,
)

try:
    musicgen_model, loaded_model_name = load_musicgen_with_fallback()
    musicgen_model.set_generation_params(**_DEFAULT_GENERATION_PARAMS)
    sample_rate = musicgen_model.sample_rate
    print(f"Model active: {loaded_model_name}. Sample rate: {sample_rate}Hz")
except SystemExit:
    sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_eq(segment):
    """Return *segment* high-passed at 60 Hz, low-passed at 12 kHz and lowered by 2 dB.

    Uses pydub's built-in filters; ``segment - 2.0`` applies -2 dB of gain.
    """
    filtered = segment.high_pass_filter(60).low_pass_filter(12000)
    return filtered - 2.0
|
|
|
|
|
def apply_limiter(segment, max_db=-6.0, target_lufs=-16.0):
    """Loudness-normalize *segment* to *target_lufs*, then cap its peak at *max_db* dBFS.

    The samples are converted to float in [-1, 1), measured with pyloudnorm's
    integrated-loudness meter and gain-adjusted to ``target_lufs``. If the
    resulting peak exceeds ``max_db`` the whole signal is scaled down so the
    peak sits at ``max_db``. This is a static gain change, not a dynamic limiter.

    Args:
        segment: pydub AudioSegment; non-16-bit audio is converted to 16-bit first.
        max_db: Peak ceiling in dBFS.
        target_lufs: Integrated-loudness target in LUFS.

    Returns:
        A new 16-bit AudioSegment with the same frame rate and channel count.
    """
    if segment.sample_width != 2:
        # The int16 <-> float conversion below assumes 16-bit PCM.
        segment = segment.set_sample_width(2)

    full_scale = 2 ** 15
    samples = np.array(segment.get_array_of_samples(), dtype=np.float32) / full_scale
    if segment.channels > 1:
        # pydub interleaves channels; pyloudnorm wants (n_samples, n_channels).
        samples = samples.reshape(-1, segment.channels)

    meter = pyln.Meter(segment.frame_rate)
    try:
        loudness = meter.integrated_loudness(samples)
    except ValueError:
        # pyloudnorm rejects audio shorter than one gating block.
        loudness = float("-inf")

    if np.isfinite(loudness):
        normalized = pyln.normalize.loudness(samples, loudness, target_lufs)
    else:
        # Silent (fully gated) or too-short audio: the normalization gain would be
        # infinite and turn the samples into inf/NaN, so only apply the peak cap.
        normalized = samples

    ceiling = 10 ** (max_db / 20)
    peak = float(np.max(np.abs(normalized))) if normalized.size else 0.0
    if peak > ceiling:
        normalized = normalized * (ceiling / peak)

    # Clip before casting: a peak of exactly 1.0 (max_db >= 0) maps to 32768,
    # which would otherwise wrap around to -32768 in int16.
    pcm = np.clip(normalized * full_scale, -full_scale, full_scale - 1).astype(np.int16)
    limited = AudioSegment(
        pcm.tobytes(),
        frame_rate=segment.frame_rate,
        sample_width=2,
        channels=segment.channels,
    )
    del samples, normalized, pcm
    gc.collect()
    return limited
|
|
|
|
|
def apply_fade(segment, fade_in_duration=1000, fade_out_duration=1000):
    """Return *segment* with a fade-in then a fade-out (pydub durations, in milliseconds)."""
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fallback song keywords used when a prompt yields no usable word.
made_up_names = (
    "blazepulse shadowrift neonquest thunderclash stargroove "
    "mysticvibe ironspark ghostsurge velvetstorm crimsonrush "
    "duskblitz solarflame nightdrift frostsaga emberwave "
    "coolriff wildpulse echoslash moontide skydive"
).split()
|
|
|
|
|
def extract_song_keyword(prompt):
    """Return the first short alphanumeric word of *prompt*, lower-cased.

    Words are ``\\b\\w+\\b`` matches; a word qualifies when it has at most 15
    characters and ``str.isalnum()`` is true (so words containing "_" are
    skipped). When *prompt* is empty or nothing qualifies, a random entry of
    ``made_up_names`` is returned instead.
    """
    if prompt:
        qualifying = (
            token
            for token in re.findall(r'\b\w+\b', prompt.lower())
            if len(token) <= 15 and token.isalnum()
        )
        keyword = next(qualifying, None)
        if keyword is not None:
            return keyword
    return random.choice(made_up_names)
|
|
|
|
|
def generate_unique_title(existing_titles, genre, song_keyword, style):
    """Pick a random title base not used by any existing song, plus a band name.

    A candidate base is one uppercase letter and one digit (e.g. "A1"). If that
    base is already taken, the same attempt also tries it with a lower-case
    letter+digit suffix (e.g. "A1_b2"). Stored titles are either the base itself
    or "<base>_Chunk<n>", so a base counts as taken when any existing title is
    equal to it once the "_Chunk..." part is removed.

    Args:
        existing_titles: Titles already recorded in the metadata file.
        genre: Key into the [BandNames] section of ``config``. Unknown genres fall
            back to the "nirvana" entry, then to "Nirvana".
        song_keyword, style: Accepted for backward compatibility. Titles never
            contain them, so they play no part in the uniqueness check.

    Returns:
        (title_base, band_name)

    Raises:
        ValueError: no free title was found within 100 attempts.
        KeyError: ``config`` has no [BandNames] section.
    """
    letters = string.ascii_uppercase
    digits = string.digits
    max_attempts = 100

    band_section = config['BandNames']
    raw_bands = band_section.get(genre) or band_section.get('nirvana', 'Nirvana')
    band_choices = [name.strip() for name in raw_bands.split(',') if name.strip()] or ['Nirvana']

    # Strip any "_Chunk<n>" part so chunk titles map back to their base.
    used_bases = {str(title).split("_Chunk", 1)[0] for title in existing_titles}

    for _ in range(max_attempts):
        base = f"{random.choice(letters)}{random.choice(digits)}"
        band_name = random.choice(band_choices)
        if base not in used_bases:
            return base, band_name
        suffix = f"{random.choice(letters)}{random.choice(digits)}".lower()
        suffixed = f"{base}_{suffix}"
        if suffixed not in used_bases:
            return suffixed, band_name
    raise ValueError("Failed to generate unique title after maximum attempts")
|
|
|
|
|
def update_metadata_storage(metadata):
    """Append one song record to the JSON list in ``metadata_file`` (best effort).

    *metadata* must contain "title" and "filename". Missing optional fields get
    defaults. The "sample_rate" default (the loaded model's rate) is only looked
    up when the key is absent. The updated list is written to a sibling
    ".tmp" file and then moved over the original with ``os.replace``, so a crash
    mid-write cannot leave a truncated JSON file behind. Errors are printed and
    swallowed, never raised.
    """
    try:
        songs_metadata = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                songs_metadata = json.load(f)

        if "sample_rate" in metadata:
            sample_rate_value = metadata["sample_rate"]
        else:
            sample_rate_value = musicgen_model.sample_rate

        songs_metadata.append({
            "title": metadata["title"],
            "filename": metadata["filename"],
            "prompt": metadata.get("prompt", ""),
            "duration": metadata.get("duration", 30),
            "volume_db": metadata.get("volume_db", -24.0),
            "target_lufs": metadata.get("target_lufs", -16.0),
            "timestamp": metadata.get("timestamp", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),
            "file_path": metadata.get("file_path", ""),
            "sample_rate": sample_rate_value,
            "style": metadata.get("style", ""),
            "band_name": metadata.get("band_name", ""),
            "chunk_index": metadata.get("chunk_index", 0),
        })

        tmp_path = f"{metadata_file}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(songs_metadata, f, indent=4)
        os.replace(tmp_path, metadata_file)  # atomic swap on the same filesystem
    except Exception as e:
        print(f"ERROR: Failed to update metadata storage: {e}")
|
|
|
|
|
def load_renders():
    """Build table rows for every song recorded in ``metadata_file``.

    Returns:
        (rows, message): *rows* is a list of dicts keyed by the renders-table
        column names; *message* is a status string. Problems are reported
        through *message* with an empty *rows*, never raised. Entries that are
        not objects or lack "title"/"filename" are skipped instead of failing
        the whole table.
    """
    from html import escape
    from urllib.parse import quote

    if not os.path.exists(metadata_file):
        return [], "No renders found."
    try:
        with open(metadata_file, 'r') as f:
            songs_metadata = json.load(f)
        if not isinstance(songs_metadata, list):
            return [], "Error loading renders: metadata file does not contain a list"

        renders = []
        for entry in songs_metadata:
            if not isinstance(entry, dict) or "title" not in entry or "filename" not in entry:
                continue
            title = entry["title"]
            filename = entry["filename"]
            # Filenames embed band names (which may contain spaces) and titles are
            # placed inside an HTML attribute: percent-encode / escape both.
            href = f"/get-song/{quote(str(filename))}"
            label = escape(f"Download {title}", quote=True)
            renders.append({
                "Title": title,
                "Filename": filename,
                "Prompt": entry.get("prompt", ""),
                "Duration (s)": entry.get("duration"),
                "Timestamp": entry.get("timestamp", ""),
                "Audio": entry.get("file_path", ""),
                "Download": f'<a href="{href}" download><button class="download-btn" aria-label="{label}">⬇️</button></a>',
                "Chunk": entry.get("chunk_index", 0),
            })
        return renders, "Renders loaded successfully."
    except Exception as e:
        return [], f"Error loading renders: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_genre_prompt(genre):
    """Build a randomized text prompt for *genre* from its INI template.

    One random value is drawn for every key of ``prompt_variables``, in that
    dict's order, and substituted into the genre's [Prompts] template (or a
    grunge default when the genre has none). A fixed "{style} music with ..."
    fallback is used when the template cannot be formatted, or when the result
    shares no word with the prompt vocabulary.

    Returns:
        (prompt, style): the prompt text and the style word that was drawn.
    """
    base_prompt = config['Prompts'].get(genre, "") or (
        "{style} grunge with {guitar_style} guitar, {bass_style} bass, "
        "{drum_style} drums, {vibe} vibe in {key} at {bpm} BPM"
    )
    prompt_dict = {name: random.choice(choices) for name, choices in prompt_variables.items()}
    fallback = (
        f"{prompt_dict['style']} music with {prompt_dict['guitar_style']} guitar, "
        f"{prompt_dict['bass_style']} bass, {prompt_dict['drum_style']} drums "
        f"in {prompt_dict['key']} at {prompt_dict['bpm']} BPM"
    )

    try:
        formatted_prompt = base_prompt.format(**prompt_dict)
    except (KeyError, IndexError, ValueError, AttributeError):
        # Unknown placeholder, positional "{}", unbalanced braces or "{x.attr}"
        # in a user-edited template.
        return fallback, prompt_dict['style']

    vocabulary = {
        item
        for choices in prompt_variables.values()
        if isinstance(choices, list)
        for item in choices
    }
    words = re.findall(r'\b\w+\b', formatted_prompt.lower())
    if not any(word in vocabulary for word in words):
        formatted_prompt = fallback
    return formatted_prompt, prompt_dict['style']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_chunk_oom_safe(model, text_prompt, continuation_prompt, cfg_scale, top_k, top_p, temperature, target_duration):
    """Generate one audio chunk, retrying with shorter durations on CUDA OOM.

    Candidate durations are *target_duration* followed by the fixed fallbacks
    20, 15, 12, 10, 8, 6, 4, 3, 2 s that are strictly shorter than it (a retry
    never asks for more audio). For continuations, MusicGen counts the audio
    prompt toward ``duration`` and must generate past it, so candidates not
    longer than the prompt are dropped.

    Args:
        model: MusicGen instance. Its generation params are overwritten on each attempt.
        text_prompt: Text description for the model.
        continuation_prompt: None for a fresh generation; otherwise an audio tensor
            to continue from, sampled at ``model.sample_rate`` (its last dimension
            is treated as time).
        cfg_scale, top_k, top_p, temperature: Sampling parameters.
        target_duration: Preferred duration in seconds.

    Returns:
        (audio_chunk, dur): the first item of the model's batch output and the
        duration in seconds that succeeded. For continuations, ``dur`` includes
        the prompt.

    Raises:
        RuntimeError: all candidates failed with OOM/CUDA errors (or none was
            valid). Other RuntimeErrors are re-raised unchanged.
    """
    fallback_durations = (20, 15, 12, 10, 8, 6, 4, 3, 2)
    durations_to_try = [target_duration] + [d for d in fallback_durations if d < target_duration]
    if continuation_prompt is not None:
        prompt_secs = continuation_prompt.shape[-1] / model.sample_rate
        durations_to_try = [d for d in durations_to_try if d > prompt_secs]

    last_error = None
    for dur in durations_to_try:
        try:
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            model.set_generation_params(
                duration=dur,
                use_sampling=True,
                top_k=int(top_k),
                top_p=float(top_p),
                temperature=float(temperature),
                cfg_coef=float(cfg_scale),
                two_step_cfg=False,
            )
            with torch.no_grad(), autocast('cuda', dtype=AUTOCAST_DTYPE):
                if continuation_prompt is None:
                    audio_chunk = model.generate([text_prompt], progress=False)[0]
                else:
                    audio_chunk = model.generate_continuation(
                        continuation_prompt, model.sample_rate, [text_prompt], progress=False
                    )[0]
            return audio_chunk, dur
        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" not in msg and "cuda error" not in msg:
                raise
            last_error = e
            print(f"OOM at duration {dur}s — retrying with smaller chunk...")
            torch.cuda.empty_cache()
            gc.collect()

    raise RuntimeError("Failed to generate audio chunk without CUDA OOM.") from last_error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_prompt_and_style(instrumental_prompt, genre):
    """Return the (prompt, style) pair a generation request should use.

    An empty/None prompt is replaced by a random prompt for *genre* (or
    "nirvana" when no genre is given). A custom prompt sharing no word with the
    prompt vocabulary is replaced by a random "nirvana" prompt. Otherwise the
    prompt is kept, and its first keyword becomes the style when it is a known
    style word; if not, a random style is picked.
    """
    text = instrumental_prompt or ""
    if not text.strip():
        return get_genre_prompt(genre or "nirvana")
    vocabulary = {
        item
        for choices in prompt_variables.values()
        if isinstance(choices, list)
        for item in choices
    }
    words = re.findall(r'\b\w+\b', text.lower())
    if not any(word in vocabulary for word in words):
        return get_genre_prompt("nirvana")
    keyword = extract_song_keyword(text)
    style = keyword if keyword in prompt_variables['style'] else random.choice(prompt_variables['style'])
    return text, style


def _to_stereo(audio):
    """Return *audio* as a (2, samples) float32 CPU tensor.

    Leading singleton dimensions are squeezed away. Mono input, and input with a
    channel count other than 2, is turned into two copies of its first channel.

    Raises:
        ValueError: the result cannot be shaped as (2, samples).
    """
    audio = audio.detach().cpu().to(dtype=torch.float32)
    while audio.dim() > 2 and audio.shape[0] == 1:
        audio = audio.squeeze(0)
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.dim() == 2 and audio.shape[0] != 2:
        audio = audio[:1].repeat(2, 1)
    if audio.dim() != 2 or audio.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got {audio.shape}")
    return audio


def _tensor_to_segment(audio, sample_rate):
    """Convert a (2, samples) tensor to a pydub AudioSegment via a temporary 16-bit WAV.

    The temp file gets a unique name inside ``output_dir`` and is always removed.
    """
    import tempfile

    fd, tmp_path = tempfile.mkstemp(prefix="temp_", suffix=".wav", dir=output_dir)
    os.close(fd)
    try:
        torchaudio.save(tmp_path, audio, sample_rate, bits_per_sample=16)
        return AudioSegment.from_wav(tmp_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, volume_db: float, genre: str = None):
    """Render an instrumental track in chunks of up to 30 s and save it as MP3.

    The first chunk is generated from text alone. Each later chunk continues
    from the last ``tail_sec`` seconds of the previous one. MusicGen returns a
    continuation *including* its audio prompt, so that prompt is trimmed off and
    only the newly generated audio counts toward *total_duration*.

    Each chunk is processed with ``apply_eq`` and ``apply_limiter`` (-16 LUFS,
    peak capped at *volume_db*), faded in/out at the track edges, exported as an
    MP3 and recorded via ``update_metadata_storage``. When more than one chunk
    was produced, the chunks are also joined with a 500 ms crossfade into a
    "_combined" MP3.

    Args:
        instrumental_prompt: Custom prompt; empty/None means "random prompt for *genre*".
        cfg_scale, top_k, top_p, temperature: Sampling parameters.
        total_duration: Requested length in seconds (raised to at least 30).
        volume_db: Peak ceiling in dBFS passed to ``apply_limiter``.
        genre: Genre key used for empty prompts and band names (default "nirvana").

    Returns:
        (mp3_path or None, status message, False, gr.update for the renders table).
        ``api_status`` is reset to "idle" however the call ends.
    """
    global api_status
    api_status = "rendering"
    try:
        instrumental_prompt, style = _resolve_prompt_and_style(instrumental_prompt, genre)

        sr = musicgen_model.sample_rate
        base_chunk_target = 30
        tail_sec = 2
        target_lufs = -16.0
        total_duration = max(total_duration, 30)
        remaining = total_duration
        audio_chunks = []
        chunk_paths = []
        continuation_prompt = None
        chunk_index = 0

        existing_titles = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                existing_titles = [entry["title"] for entry in json.load(f)]
        song_keyword = extract_song_keyword(instrumental_prompt)
        title_base, band_name = generate_unique_title(existing_titles, genre if genre else "nirvana", song_keyword, style)
        file_stem = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}"

        while remaining > 0:
            chunk_no = chunk_index + 1
            prompt_secs = 0.0 if continuation_prompt is None else continuation_prompt.shape[-1] / sr
            # MusicGen's duration includes the audio prompt: ask for prompt + the
            # audio still needed, and always for at least 1 s beyond the prompt.
            target = min(base_chunk_target, remaining + prompt_secs)
            if continuation_prompt is not None:
                target = max(target, prompt_secs + 1)

            print_resource_usage(f"Before Chunk {chunk_no}")
            try:
                raw_chunk, actual_dur = generate_chunk_oom_safe(
                    musicgen_model, instrumental_prompt, continuation_prompt,
                    cfg_scale, top_k, top_p, temperature, target,
                )
                raw_chunk = _to_stereo(raw_chunk)

                if continuation_prompt is not None:
                    # Drop the re-rendered prompt so consecutive chunks don't repeat it.
                    new_audio = raw_chunk[:, continuation_prompt.shape[-1]:]
                else:
                    new_audio = raw_chunk
                new_secs = actual_dur - prompt_secs
                if new_secs <= 0 or new_audio.shape[1] == 0:
                    # Guard against an endless loop if no new audio came back.
                    raise RuntimeError(f"Chunk {chunk_no} produced no new audio (duration {actual_dur}s).")

                # The next chunk continues from the end of this one.
                tail_samples = min(int(tail_sec * sr), raw_chunk.shape[1] - 1 if raw_chunk.shape[1] > 1 else 1)
                continuation_prompt = raw_chunk[:, -tail_samples:].clone() if tail_samples > 0 else None

                final_segment = _tensor_to_segment(new_audio, sr)
                del raw_chunk, new_audio
                gc.collect()

                print(f"Post-processing chunk {chunk_no} (duration ~{new_secs}s)...")
                final_segment = apply_eq(final_segment)
                final_segment = apply_limiter(final_segment, max_db=volume_db, target_lufs=target_lufs)
                if chunk_index == 0:
                    final_segment = final_segment.fade_in(1000)
                if remaining - new_secs <= 0:
                    final_segment = final_segment.fade_out(1000)

                mp3_filename = f"{file_stem}_chunk{chunk_no}.mp3"
                mp3_path = os.path.join(output_dir, mp3_filename)
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate="64k",
                    tags={"title": f"{title_base}_Chunk{chunk_no}", "artist": "GhostAI"},
                )
                print(f"Saved chunk {chunk_no} to {mp3_path}")
                audio_chunks.append(final_segment)
                chunk_paths.append(mp3_path)

                update_metadata_storage({
                    "title": f"{title_base}_Chunk{chunk_no}",
                    "filename": mp3_filename,
                    "prompt": instrumental_prompt,
                    "duration": new_secs,
                    "volume_db": volume_db,
                    "target_lufs": target_lufs,
                    "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                    "file_path": mp3_path,
                    "sample_rate": sr,
                    "style": style,
                    "band_name": band_name,
                    "chunk_index": chunk_no,
                })

                chunk_index += 1
                remaining -= new_secs
                torch.cuda.empty_cache()
                gc.collect()
                print_resource_usage(f"After Chunk {chunk_index}")
            except Exception as e:
                print(f"ERROR: Failed to process chunk {chunk_no}: {e}")
                raise

        if len(audio_chunks) > 1:
            combined_segment = audio_chunks[0]
            for segment in audio_chunks[1:]:
                combined_segment = combined_segment.append(segment, crossfade=500)
            combined_mp3_filename = f"{file_stem}_combined.mp3"
            combined_mp3_path = os.path.join(output_dir, combined_mp3_filename)
            combined_segment.export(
                combined_mp3_path,
                format="mp3",
                bitrate="64k",
                tags={"title": title_base, "artist": "GhostAI"},
            )
            print(f"Saved combined audio to {combined_mp3_path}")
            update_metadata_storage({
                "title": title_base,
                "filename": combined_mp3_filename,
                "prompt": instrumental_prompt,
                "duration": total_duration,
                "volume_db": volume_db,
                "target_lufs": target_lufs,
                "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                "file_path": combined_mp3_path,
                "sample_rate": sr,
                "style": style,
                "band_name": band_name,
                "chunk_index": 0,
            })
            del combined_segment, audio_chunks
            gc.collect()
            return combined_mp3_path, "✅ Done!", False, gr.update(value=load_renders()[0])

        print(f"Saved metadata to {metadata_file}")
        del audio_chunks
        gc.collect()
        return chunk_paths[0], "✅ Done!", False, gr.update(value=load_renders()[0])

    except Exception as e:
        print(f"❌ Failed: {e}")
        return None, f"❌ Failed: {e}", False, gr.update(value=load_renders()[0])

    finally:
        api_status = "idle"
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
|
|
|
|
|
def clear_inputs():
    """Return the reset values for the generation form.

    Order: prompt, CFG scale, top-k, top-p, temperature, duration (s),
    volume (dB), render-wheel visibility.
    """
    prompt, cfg_scale, top_k, top_p = "", 3.0, 50, 0.0
    temperature, duration, volume_db, wheel_visible = 0.8, 30, -24.0, False
    return prompt, cfg_scale, top_k, top_p, temperature, duration, volume_db, wheel_visible
|
|
|
|
|
def show_render_wheel():
    """Return the visibility flag that turns the render wheel on (always True)."""
    visible = True
    return visible
|
|
|
|
|
def set_genre_prompt(genre: str):
    """Return only the prompt text of the preset for *genre*.

    Click handler for the genre buttons. ``get_genre_prompt`` yields a
    (prompt, style) pair; the UI only needs the prompt for the textbox.
    """
    prompt_text, _unused_style = get_genre_prompt(genre)
    return prompt_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Genre preset buttons in display order: (button label, genre key handed to
# set_genre_prompt -> get_genre_prompt when the button is clicked).
_GENRE_PRESET_BUTTONS = [
    ("Classic Rock", "classic_rock"),
    ("Alternative Rock", "alternative_rock"),
    ("Detroit Techno", "detroit_techno"),
    ("Deep House", "deep_house"),
    ("Smooth Jazz", "smooth_jazz"),
    ("Bebop Jazz", "bebop_jazz"),
    ("Baroque Classical", "baroque_classical"),
    ("Romantic Classical", "romantic_classical"),
    ("Boom Bap Hip-Hop", "boom_bap_hiphop"),
    ("Trap Hip-Hop", "trap_hiphop"),
    ("Pop Rock", "pop_rock"),
    ("Fusion Jazz", "fusion_jazz"),
    ("EDM", "edm"),
    ("Indie Folk", "indie_folk"),
    ("Star Wars Epic", "star_wars"),
    ("Star Wars Classical", "star_wars_classical"),
    ("Nirvana", "nirvana"),
    ("Wu-Tang", "wutang"),
    ("Miles Davis", "milesdavis"),
]

# Gradio front end: a "Generate" tab (prompt, genre presets, sampling settings,
# output player) and a "Renders" tab listing previously rendered tracks.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
<div class="header-container" role="banner" aria-label="GhostAI Music Generator">
  <h1>GhostAI Music Generator</h1>
  <p>Create Professional Instrumental Tracks</p>
</div>
""")

    with gr.Tabs():
        with gr.Tab("Generate", id="generate"):
            with gr.Column(elem_classes="input-container"):
                gr.Markdown("### Instrumental Prompt")
                instrumental_prompt = gr.Textbox(
                    label="Instrumental Prompt",
                    placeholder="Select a genre or enter a custom prompt (e.g., 'coolriff grunge')",
                    lines=4,
                    elem_classes="textbox",
                )
                with gr.Row(elem_classes="genre-buttons"):
                    # (Button, genre key) pairs; their click handlers are wired
                    # below, after the layout is complete.
                    genre_buttons = [
                        (gr.Button(label, elem_classes="genre-btn"), genre_key)
                        for label, genre_key in _GENRE_PRESET_BUTTONS
                    ]

            with gr.Column(elem_classes="settings-container"):
                gr.Markdown("### Generation Settings")
                cfg_scale = gr.Slider(
                    label="Guidance Scale (CFG)",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.1,
                )
                top_k = gr.Slider(
                    label="Top-K Sampling",
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                )
                top_p = gr.Slider(
                    label="Top-P Sampling",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                )
                total_duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=30,
                    maximum=300,
                    value=30,
                    step=10,
                )
                volume_db = gr.Slider(
                    label="Output Volume (dBFS)",
                    minimum=-30.0,
                    maximum=0.0,
                    value=-24.0,
                    step=0.1,
                )
                with gr.Row(elem_classes="action-buttons"):
                    gen_btn = gr.Button("Generate Music")
                    clr_btn = gr.Button("Clear Inputs")

            with gr.Column(elem_classes="output-container"):
                gr.Markdown("### Output")
                # NOTE(review): no event in this block outputs to render_wheel, so
                # render_state does not toggle it; presumably the stylesheet (css)
                # controls its visibility — confirm.
                render_wheel = gr.HTML('<div class="render-wheel" aria-live="polite">Generating...</div>', label="Rendering Status")
                render_state = gr.State(value=False)
                out_audio = gr.Audio(label="Generated Track", type="filepath", interactive=True, elem_classes="audio-container")
                status = gr.Textbox(label="Status", interactive=False)

        with gr.Tab("Renders", id="renders"):
            with gr.Column(elem_classes="renders-container"):
                gr.Markdown("### Browse Renders")
                # Call load_renders() once so the table and its status message are
                # built from the same snapshot (and the work is not done twice).
                initial_renders = load_renders()
                renders_table = gr.DataFrame(
                    headers=["Title", "Filename", "Prompt", "Duration (s)", "Timestamp", "Audio", "Download", "Chunk"],
                    datatype=["str", "str", "str", "number", "str", "audio", "html", "number"],
                    interactive=False,
                    value=initial_renders[0],
                    elem_classes="renders-table",
                )
                renders_status = gr.Textbox(label="Renders Status", interactive=False, value=initial_renders[1])

    # Each genre button writes its preset prompt into the prompt textbox. The
    # key travels as a gr.State input, so there is no late-binding closure issue.
    for _genre_button, _genre_key in genre_buttons:
        _genre_button.click(set_genre_prompt, inputs=[gr.State(value=_genre_key)], outputs=[instrumental_prompt])

    # Generate: flag rendering first, then run the generator, which returns
    # (audio path, status text, render flag, refreshed renders table).
    gen_btn.click(
        fn=show_render_wheel,
        inputs=None,
        outputs=[render_state],
    ).then(
        fn=generate_music,
        inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, volume_db, gr.State(None)],
        outputs=[out_audio, status, render_state, renders_table],
        show_progress="full",
    )

    # Clear: clear_inputs' tuple order must match this outputs list.
    clr_btn.click(
        fn=clear_inputs,
        inputs=None,
        outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, volume_db, render_state],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application for the HTTP endpoints defined below (/prompts/,
# /generate-music/, /get-song/{filename}, /status/). It is served by
# run_fastapi() in a separate process started from __main__.
app = FastAPI()
|
|
|
|
|
class MusicRequest(BaseModel):
    """JSON body accepted by ``POST /generate-music/``.

    api_generate_music resolves the prompt in this order: ``genre`` (if set)
    supplies both prompt and style; otherwise ``prompt`` is used; otherwise the
    "nirvana" preset is used.
    """

    # Free-text instrumental prompt. Not used for generation when ``genre`` is
    # set, but still feeds the title keyword if provided.
    # NOTE(review): annotated ``str`` yet defaults to None. Under pydantic v2 an
    # explicit JSON null for this field would fail validation; consider
    # ``Optional[str]`` — confirm the installed pydantic major version.
    prompt: str = None

    # Requested length in seconds. api_generate_music raises values below 30 to
    # 30 and renders in chunks of at most 30 s.
    duration: int = 30

    # Passed to apply_limiter as ``max_db`` for every chunk.
    volume_db: float = -24.0

    # Preset key for get_genre_prompt (e.g. "nirvana", "edm"); takes priority
    # over ``prompt``.
    # NOTE(review): same ``str`` vs None annotation caveat as ``prompt``.
    genre: str = None
|
|
|
|
|
@app.get("/prompts/")
async def get_prompts():
    """Return the available prompt names together with the API render status.

    The names are the keys of ``config['Prompts']``. Any failure while reading
    them is logged and surfaced as HTTP 500.
    """
    try:
        prompts_section = config['Prompts']
        prompt_names = list(prompts_section.keys())
        payload = {"status": api_status, "prompts": prompt_names}
        return payload
    except Exception as e:
        print(f"Error fetching prompts: {e}")
        raise HTTPException(status_code=500, detail=f"Error fetching prompts: {e}")
|
|
|
|
|
@app.post("/generate-music/")
async def api_generate_music(request: MusicRequest):
    """Render a track for an HTTP request and return it as an MP3 file.

    Prompt/style resolution: ``request.genre`` (via get_genre_prompt) wins,
    then ``request.prompt``, then the "nirvana" preset. Audio is generated in
    chunks of at most 30 s (total at least 30 s). Each chunk is conditioned on
    the last 2 s of the previous one and saved as its own MP3 with a metadata
    entry. When more than one chunk was produced, the chunks are crossfaded
    (500 ms) into a combined MP3, which also gets a metadata entry.

    Args:
        request: Parsed request body (prompt, duration, volume_db, genre).

    Returns:
        FileResponse for the combined MP3, or for the only chunk when a single
        chunk was rendered.

    Raises:
        HTTPException: 400 if the resolved prompt is blank; 500 for any other
            failure during generation or export.
    """
    import tempfile  # local import: only this endpoint needs collision-free temp names

    global api_status
    api_status = "rendering"
    try:
        if request.genre:
            # One call supplies both halves, so prompt and style stay paired even
            # if get_genre_prompt varies its output between calls (TODO confirm).
            genre_preset = get_genre_prompt(request.genre)
            instrumental_prompt = genre_preset[0]
            style = genre_preset[1]
        else:
            default_preset = get_genre_prompt("nirvana")
            instrumental_prompt = request.prompt if request.prompt else default_preset[0]
            prompt_keyword = extract_song_keyword(request.prompt) if request.prompt else None
            if request.prompt and prompt_keyword in prompt_variables['style']:
                style = prompt_keyword
            else:
                style = default_preset[1]

        if not instrumental_prompt.strip():
            raise HTTPException(status_code=400, detail="Invalid prompt or genre")

        total_duration = max(request.duration, 30)
        remaining = total_duration
        audio_chunks = []
        chunk_paths = []
        continuation_prompt = None
        chunk_index = 0

        existing_titles = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                songs_metadata = json.load(f)
            existing_titles = [entry["title"] for entry in songs_metadata]
        song_keyword = extract_song_keyword(request.prompt if request.prompt else instrumental_prompt)
        title_base, band_name = generate_unique_title(existing_titles, request.genre if request.genre else "nirvana", song_keyword, style)

        while remaining > 0:
            target = min(30, remaining)
            print_resource_usage(f"Before API Chunk {chunk_index + 1}")
            try:
                audio_chunk, actual_dur = generate_chunk_oom_safe(
                    musicgen_model, instrumental_prompt, continuation_prompt, 3.0, 50, 0.0, 0.8, target
                )
                if actual_dur <= 0:
                    # Otherwise `remaining` never shrinks and this loop spins forever.
                    raise RuntimeError(f"API chunk {chunk_index + 1} returned non-positive duration {actual_dur}")

                audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
                # Normalise to (2, samples). Extra leading dims (e.g. a batch axis)
                # are folded into the channel axis so time stays on the last axis;
                # view(2, -1) would instead split a mono (1, 1, T) track into two
                # half-length "channels".
                if audio_chunk.dim() > 2:
                    audio_chunk = audio_chunk.reshape(-1, audio_chunk.shape[-1])
                if audio_chunk.dim() == 1:
                    audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
                elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
                    # Mono or an unexpected channel count: duplicate the first channel.
                    first_channel = audio_chunk[:1, :]
                    audio_chunk = torch.cat([first_channel, first_channel], dim=0)
                if audio_chunk.dim() != 2 or audio_chunk.shape[0] != 2:
                    raise ValueError(f"Expected stereo audio with shape (2, samples), got {audio_chunk.shape}")

                # The last 2 s condition the next chunk. clone() so this small tail
                # does not keep the whole chunk's storage alive after `del audio_chunk`.
                samples_per_second = musicgen_model.sample_rate
                tail_sec = 2
                tail_samples = min(int(tail_sec * samples_per_second), audio_chunk.shape[1] - 1 if audio_chunk.shape[1] > 1 else 1)
                continuation_prompt = audio_chunk[:, -tail_samples:].clone() if tail_samples > 0 else None

                # Round-trip through a 16-bit WAV to get a pydub AudioSegment. mkstemp
                # gives a unique name, so renders running concurrently in the Gradio
                # and API processes cannot overwrite each other's temp file.
                temp_fd, temp_wav_path = tempfile.mkstemp(prefix="temp_", suffix=f"_{chunk_index}.wav", dir=output_dir)
                os.close(temp_fd)
                try:
                    torchaudio.save(temp_wav_path, audio_chunk, musicgen_model.sample_rate, bits_per_sample=16)
                    final_segment = AudioSegment.from_wav(temp_wav_path)
                finally:
                    if os.path.exists(temp_wav_path):
                        os.remove(temp_wav_path)
                del audio_chunk
                gc.collect()

                final_segment = apply_eq(final_segment)
                final_segment = apply_limiter(final_segment, max_db=request.volume_db, target_lufs=-16.0)
                if chunk_index == 0:
                    final_segment = final_segment.fade_in(1000)
                if remaining - actual_dur <= 0:
                    final_segment = final_segment.fade_out(1000)

                mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_chunk{chunk_index + 1}.mp3"
                mp3_path = os.path.join(output_dir, mp3_filename)
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate="64k",
                    tags={"title": f"{title_base}_Chunk{chunk_index + 1}", "artist": "GhostAI"},
                )
                print(f"Saved API chunk {chunk_index + 1} to {mp3_path}")
                audio_chunks.append(final_segment)
                chunk_paths.append(mp3_path)

                metadata = {
                    "title": f"{title_base}_Chunk{chunk_index + 1}",
                    "filename": mp3_filename,
                    "prompt": instrumental_prompt,
                    "duration": actual_dur,
                    "volume_db": request.volume_db,
                    "target_lufs": -16.0,
                    "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                    "file_path": mp3_path,
                    "sample_rate": musicgen_model.sample_rate,
                    "style": style,
                    "band_name": band_name,
                    "chunk_index": chunk_index + 1,
                }
                update_metadata_storage(metadata)

                chunk_index += 1
                remaining -= actual_dur
                torch.cuda.empty_cache()
                gc.collect()
                print_resource_usage(f"After API Chunk {chunk_index}")
            except Exception as e:
                print(f"ERROR: Failed to process API chunk {chunk_index + 1}: {e}")
                raise

        if len(audio_chunks) > 1:
            combined_segment = audio_chunks[0]
            for segment in audio_chunks[1:]:
                combined_segment = combined_segment.append(segment, crossfade=500)
            combined_mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_combined.mp3"
            combined_mp3_path = os.path.join(output_dir, combined_mp3_filename)
            combined_segment.export(
                combined_mp3_path,
                format="mp3",
                bitrate="64k",
                tags={"title": title_base, "artist": "GhostAI"},
            )
            print(f"Saved combined audio to {combined_mp3_path}")
            metadata = {
                "title": title_base,
                "filename": combined_mp3_filename,
                "prompt": instrumental_prompt,
                "duration": total_duration,
                "volume_db": request.volume_db,
                "target_lufs": -16.0,
                "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                "file_path": combined_mp3_path,
                "sample_rate": musicgen_model.sample_rate,
                "style": style,
                "band_name": band_name,
                "chunk_index": 0,
            }
            update_metadata_storage(metadata)
            del combined_segment, audio_chunks
            gc.collect()
            return FileResponse(combined_mp3_path, media_type="audio/mpeg")
        else:
            print(f"Saved metadata to {metadata_file}")
            del audio_chunks
            gc.collect()
            return FileResponse(chunk_paths[0], media_type="audio/mpeg")
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) must reach the caller
        # unchanged instead of being rewrapped as a 500.
        raise
    except Exception as e:
        print(f"Error generating music: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating music: {e}")
    finally:
        # Single reset point for every exit path, including failed cleanup below.
        api_status = "idle"
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
|
|
|
|
|
@app.get("/get-song/{filename}")
async def get_song(filename: str):
    """Serve a rendered file from ``output_dir`` as ``audio/mpeg``.

    Args:
        filename: Bare file name (no directory components) of a file located
            directly inside ``output_dir``.

    Raises:
        HTTPException: 404 if *filename* does not name a regular file directly
            inside ``output_dir``. This also rejects ``..``, absolute paths and
            backslash-separated components on Windows, which would otherwise
            escape the directory.
    """
    base_dir = os.path.abspath(output_dir)
    file_path = os.path.abspath(os.path.join(base_dir, filename))
    # The name comes straight from the URL. Accept only paths whose parent is
    # output_dir itself, and only regular files; a directory would make
    # FileResponse fail with a 500 instead of a clean 404.
    if os.path.dirname(file_path) != base_dir or not os.path.isfile(file_path):
        print(f"Error: Song file {filename} not found")
        raise HTTPException(status_code=404, detail="Song file not found")
    print(f"Serving file: {filename}")
    return FileResponse(file_path, media_type="audio/mpeg", filename=filename)
|
|
|
|
|
@app.get("/status/")
async def get_status():
    """Report this API process's render state as ``{"status": api_status}``.

    api_generate_music sets the value to "rendering" on entry and back to
    "idle" when it finishes. NOTE(review): the API runs in a separate process
    (see run_fastapi / __main__), so renders started from the Gradio UI are
    presumably not reflected here — confirm.
    """
    snapshot = {"status": api_status}
    return snapshot
|
|
|
|
|
def run_fastapi():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000; blocks until the server stops.

    Target of the background multiprocessing.Process started under __main__.
    """
    listen_host, listen_port = "0.0.0.0", 8000
    uvicorn.run(app, host=listen_host, port=listen_port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the HTTP API in its own process (the 'spawn' start method is set at
    # import time) so it serves requests while Gradio's launch() blocks below.
    fastapi_process = multiprocessing.Process(target=run_fastapi)
    fastapi_process.start()
    try:
        demo.launch(server_name="0.0.0.0", server_port=9999, share=False, inbrowser=True, show_error=True)
    except Exception as e:
        print(f"ERROR: Failed to launch Gradio: {e}")
        sys.exit(1)  # the finally clause below still shuts the API process down
    finally:
        # One shutdown path for normal exit, launch errors and Ctrl+C. Stop the
        # API process and wait (bounded) for it to exit instead of abandoning it.
        fastapi_process.terminate()
        fastapi_process.join(timeout=10)
|
|
|