ghostai1 committed on
Commit
5335380
·
verified ·
1 Parent(s): 572b8f5

Create STABLE12gb3060.py

Browse files
Files changed (1) hide show
  1. STABLE12gb3060.py +1280 -0
STABLE12gb3060.py ADDED
@@ -0,0 +1,1280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import time
5
+ import sys
6
+ import numpy as np
7
+ import gc
8
+ import gradio as gr
9
+ from pydub import AudioSegment
10
+ from audiocraft.models import MusicGen
11
+ from torch.cuda.amp import autocast
12
+ import warnings
13
+ import random
14
+ import traceback
15
+ import logging
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+ import mmap
19
+ import subprocess
20
+ import re
21
+ import gradio_client.utils
22
+
23
+ # Patch for Gradio bug
24
# Keep a handle on the stock implementation so the patch can delegate to it.
original_get_type = gradio_client.utils.get_type


def patched_get_type(schema):
    """Return the type name for ``schema``, tolerating bare boolean schemas.

    A schema that is itself a ``bool`` is reported as ``"boolean"``; any
    other schema is forwarded unchanged to the original gradio_client
    function captured above.
    """
    if not isinstance(schema, bool):
        return original_get_type(schema)
    return "boolean"


gradio_client.utils.get_type = patched_get_type
30
+
31
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Set PYTORCH_CUDA_ALLOC_CONF for CUDA 12
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# cuDNN: deterministic kernels, no per-shape autotuning
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Logging: one timestamped file per run under ./logs, mirrored to stdout
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
_log_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f"musicgen_log_{_log_stamp}.log")
_log_handlers = [
    logging.FileHandler(log_file),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
54
+
55
# Device setup: the script refuses to run without CUDA.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    logger.error("CUDA is required for GPU rendering. CPU rendering is disabled.")
    sys.exit(1)
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)} (CUDA 12)")
logger.info("Using precision: float16 for model, float32 for CPU processing")
62
+
63
+ # Memory cleanup function
64
def clean_memory():
    """Release cached CUDA memory and report what PyTorch still holds.

    Returns:
        The value of ``torch.cuda.memory_allocated()`` divided by 1024**2,
        or ``None`` if any cleanup step raised (the error is logged).
    """
    try:
        # Order matters: drop cached blocks, collect Python garbage,
        # reclaim IPC handles, then wait for outstanding GPU work.
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        allocated_bytes = torch.cuda.memory_allocated()
        vram_mb = allocated_bytes / 1024**2
        logger.info(f"Memory cleaned: VRAM allocated = {vram_mb:.2f} MB")
        logger.debug(f"VRAM summary: {torch.cuda.memory_summary()}")
    except Exception as e:
        logger.error(f"Failed to clean memory: {e}")
        logger.error(traceback.format_exc())
        return None
    return vram_mb
78
+
79
+ # Check VRAM and external processes
80
def check_vram():
    """Query ``nvidia-smi`` for GPU memory usage and log the result.

    When less than 5000 MiB is free, a warning is logged together with the
    list of compute processes reported by ``nvidia-smi``.

    Returns:
        Free VRAM in MiB as an ``int``, or ``None`` when ``nvidia-smi`` could
        not be run, printed no data row, or printed something unparsable.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'],
            capture_output=True,
            text=True,
        )
        lines = result.stdout.splitlines()
        # Line 0 is the CSV header; without a second line there is nothing to
        # parse. Make that path explicit instead of falling through.
        if len(lines) <= 1:
            logger.error("Failed to check VRAM: nvidia-smi returned no GPU memory data")
            return None
        used_mb, total_mb = map(int, re.findall(r'\d+', lines[1]))
        free_mb = total_mb - used_mb
        logger.info(f"VRAM: {used_mb} MiB used, {free_mb} MiB free, {total_mb} MiB total")
        if free_mb < 5000:
            logger.warning(f"Low free VRAM ({free_mb} MiB). Close other applications or processes.")
            # NOTE(review): the process listing is assumed to belong to the
            # low-VRAM branch (original indentation was lost) - confirm.
            result = subprocess.run(
                ['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'],
                capture_output=True,
                text=True,
            )
            logger.info(f"GPU processes:\n{result.stdout}")
        return free_mb
    except Exception as e:
        logger.error(f"Failed to check VRAM: {e}")
        return None
96
+
97
# Pre-run VRAM check and cleanup
free_vram = check_vram()
low_vram_at_start = free_vram is not None and free_vram < 5000
if low_vram_at_start:
    logger.warning("Consider terminating high-VRAM processes before continuing.")
clean_memory()

# Load MusicGen large model into VRAM
try:
    logger.info("Loading MusicGen large model into VRAM...")
    local_model_path = "./models/musicgen-large"
    if not os.path.exists(local_model_path):
        for message in (
            f"Local model path {local_model_path} does not exist.",
            "Please download the MusicGen large model weights and place them in the correct directory.",
        ):
            logger.error(message)
        # SystemExit is not an Exception subclass, so the handler below
        # does not intercept this exit.
        sys.exit(1)
    with autocast(dtype=torch.float16):
        musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
    musicgen_model.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen large model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load MusicGen model: {e}")
    logger.error(traceback.format_exc())
    sys.exit(1)
122
+
123
+ # Check disk space
124
def check_disk_space(path="."):
    """Report whether the filesystem holding ``path`` has at least 1 GB free.

    Uses ``shutil.disk_usage`` rather than ``os.statvfs`` so the check also
    works on Windows, where ``os.statvfs`` does not exist and the old code
    always reported failure.

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        ``True`` when at least 1.0 GB (1024**3 bytes) is free, ``False`` when
        less is free or the query itself failed (the failure is logged).
    """
    import shutil  # local import keeps this block self-contained

    try:
        free_space = shutil.disk_usage(path).free / (1024**3)
    except Exception as e:
        logger.error(f"Failed to check disk space: {e}")
        return False
    if free_space < 1.0:
        logger.warning(f"Low disk space ({free_space:.2f} GB). Ensure at least 1 GB free.")
    return free_space >= 1.0
134
+
135
+ # Audio processing functions (CPU-based)
136
def ensure_stereo(audio_segment, sample_rate=48000, sample_width=2):
    """Return ``audio_segment`` as 2-channel audio at ``sample_rate``.

    The segment is converted only when its channel count or frame rate
    differs. ``sample_width`` is accepted for call compatibility but is not
    used. On failure the error is logged and the segment is returned in
    whatever state it had reached.
    """
    try:
        channel_count = audio_segment.channels
        if channel_count != 2:
            logger.debug(f"Converting to stereo: {channel_count} channels detected")
            audio_segment = audio_segment.set_channels(2)
        needs_resample = audio_segment.frame_rate != sample_rate
        if needs_resample:
            logger.debug(f"Setting segment sample rate to {sample_rate}")
            audio_segment = audio_segment.set_frame_rate(sample_rate)
    except Exception as e:
        logger.error(f"Failed to ensure stereo: {e}")
        logger.error(traceback.format_exc())
    return audio_segment
150
+
151
def balance_stereo(audio_segment, noise_threshold=-40, sample_rate=48000):
    """Equalise the RMS level of the left and right channels.

    Samples whose level ``20*log10(|sample|)`` is not above
    ``noise_threshold`` are zeroed, the RMS of the remaining non-zero samples
    is measured per channel, and both channels are scaled to the mean of the
    two RMS values. The scaled samples are clipped to the integer range
    before the cast so a boosted channel cannot wrap around.

    NOTE(review): the level is computed from raw integer sample values, so
    the threshold is relative to a magnitude of 1, not to full scale -
    confirm this is intended.

    Args:
        audio_segment: pydub ``AudioSegment`` to balance.
        noise_threshold: Level below which samples are zeroed.
        sample_rate: Frame rate used for conversion and for the result.

    Returns:
        A new balanced ``AudioSegment``; the input (after stereo conversion)
        when it is not 2-channel or an error occurred.
    """
    logger.debug(f"Balancing stereo for segment with sample rate {sample_rate}")
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
        if audio_segment.channels != 2:
            logger.error("Failed to ensure stereo channels")
            return audio_segment

        stereo_samples = samples.reshape(-1, 2)
        db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10)
        stereo_samples = stereo_samples * (db_samples > noise_threshold)

        left_nonzero = stereo_samples[:, 0][stereo_samples[:, 0] != 0]
        right_nonzero = stereo_samples[:, 1][stereo_samples[:, 1] != 0]
        left_rms = np.sqrt(np.mean(left_nonzero**2)) if len(left_nonzero) > 0 else 0
        right_rms = np.sqrt(np.mean(right_nonzero**2)) if len(right_nonzero) > 0 else 0
        if left_rms > 0 and right_rms > 0:
            avg_rms = (left_rms + right_rms) / 2
            stereo_samples[:, 0] = stereo_samples[:, 0] * (avg_rms / left_rms)
            stereo_samples[:, 1] = stereo_samples[:, 1] * (avg_rms / right_rms)

        # Clip before the integer cast: the quieter channel is boosted and may
        # exceed the sample range, which would otherwise overflow on astype.
        is_24_bit = audio_segment.sample_width == 3
        peak = 2**23 - 1 if is_24_bit else 32767
        out_dtype = np.int32 if is_24_bit else np.int16
        balanced_samples = (
            np.clip(stereo_samples, -peak - 1, peak).flatten().astype(out_dtype)
        )
        balanced_segment = AudioSegment(
            balanced_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2,
        )
        logger.debug("Stereo balancing completed")
        return balanced_segment
    except Exception as e:
        logger.error(f"Failed to balance stereo: {e}")
        logger.error(traceback.format_exc())
        return audio_segment
186
+
187
def calculate_rms(segment):
    """Return the root-mean-square of all samples in ``segment``.

    Samples come from ``segment.get_array_of_samples()`` and are processed
    as float32. Returns 0 (after logging) if anything raises.
    """
    try:
        pcm = np.array(segment.get_array_of_samples(), dtype=np.float32)
        mean_square = np.mean(pcm**2)
        rms = np.sqrt(mean_square)
        logger.debug(f"Calculated RMS: {rms}")
    except Exception as e:
        logger.error(f"Failed to calculate RMS: {e}")
        logger.error(traceback.format_exc())
        return 0
    return rms
197
+
198
def rms_normalize(segment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000):
    """Apply gain so the segment's RMS reaches ``target_rms_db``, then hard-limit.

    The target is converted to a linear value against a full scale of 2**23
    (sample width 3) or 32767 (otherwise). Gain is applied only when the
    measured RMS is positive; the hard limiter at ``peak_limit_db`` always
    runs. On error the segment is returned in whatever state it had reached.
    """
    logger.debug(f"Normalizing RMS for segment with target {target_rms_db} dBFS")
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        full_scale = 2**23 if segment.sample_width == 3 else 32767
        target_rms = 10 ** (target_rms_db / 20) * full_scale
        current_rms = calculate_rms(segment)
        if current_rms > 0:
            gain_db = 20 * np.log10(target_rms / current_rms)
            segment = segment.apply_gain(gain_db)
        segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate)
        logger.debug("RMS normalization completed")
    except Exception as e:
        logger.error(f"Failed to normalize RMS: {e}")
        logger.error(traceback.format_exc())
    return segment
214
+
215
def hard_limit(audio_segment, limit_db=-3.0, sample_rate=48000):
    """Clip every sample to +/- the linear value of ``limit_db``.

    The ceiling is ``10**(limit_db/20)`` times a full scale of 2**23 (sample
    width 3) or 32767 (otherwise). Returns a new 2-channel ``AudioSegment``
    built from the clipped samples, or the input segment on error.
    """
    logger.debug(f"Applying hard limit at {limit_db} dBFS")
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        is_24_bit = audio_segment.sample_width == 3
        ceiling = 10 ** (limit_db / 20.0) * (2**23 if is_24_bit else 32767)
        pcm = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
        clipped = np.clip(pcm, -ceiling, ceiling).astype(np.int32 if is_24_bit else np.int16)
        # Keep whole stereo frames only.
        if len(clipped) % 2 != 0:
            clipped = clipped[:-1]
        limited_segment = AudioSegment(
            clipped.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2,
        )
        logger.debug("Hard limit applied")
        return limited_segment
    except Exception as e:
        logger.error(f"Failed to apply hard limit: {e}")
        logger.error(traceback.format_exc())
        return audio_segment
236
+
237
def apply_noise_gate(audio_segment, threshold_db=-80, sample_rate=48000):
    """Zero every sample whose level is not above ``threshold_db``.

    The level is ``20*log10(|sample| + 1e-10)`` computed on the raw sample
    values, and the gate is applied in two identical passes. Returns a new
    2-channel ``AudioSegment``; the input segment is returned when it is not
    stereo after conversion or when an error occurs.
    """
    logger.debug(f"Applying noise gate with threshold {threshold_db} dBFS")
    try:
        audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
        pcm = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
        if audio_segment.channels != 2:
            logger.error("Failed to ensure stereo channels for noise gate")
            return audio_segment

        frames = pcm.reshape(-1, 2)
        # Two passes, as in the original ("simulate faster attack/release").
        for _ in range(2):
            level_db = 20 * np.log10(np.abs(frames) + 1e-10)
            frames = frames * (level_db > threshold_db)

        out_dtype = np.int32 if audio_segment.sample_width == 3 else np.int16
        gated_samples = frames.flatten().astype(out_dtype)
        if len(gated_samples) % 2 != 0:
            gated_samples = gated_samples[:-1]
        gated_segment = AudioSegment(
            gated_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2,
        )
        logger.debug("Noise gate applied")
        return gated_segment
    except Exception as e:
        logger.error(f"Failed to apply noise gate: {e}")
        logger.error(traceback.format_exc())
        return audio_segment
268
+
269
def apply_eq(segment, sample_rate=48000):
    """Band-limit the segment to 20 Hz - 8 kHz and lower its gain.

    Steps, in order: stereo/sample-rate conversion, 20 Hz high-pass, 8 kHz
    low-pass, then three gain subtractions of 3, 3 and 10.

    NOTE(review): ``segment - N`` is applied to the whole segment here, not
    to a frequency band, so the "notch"/"high-shelf" wording in the log
    message presumably overstates what happens - confirm against pydub.

    On error the segment is returned in whatever state it had reached.
    """
    logger.debug(f"Applying EQ with sample rate {sample_rate}")
    try:
        segment = ensure_stereo(segment, sample_rate, segment.sample_width)
        segment = segment.high_pass_filter(20)
        segment = segment.low_pass_filter(8000)
        # Kept as three separate steps so rounding matches the original.
        for reduction in (3, 3, 10):
            segment = segment - reduction
        logger.debug("EQ applied: 8 kHz low-pass, 3 dB reduction at 1-8 kHz, 3 dB notch at 12 kHz, 10 dB high-shelf above 5 kHz")
    except Exception as e:
        logger.error(f"Failed to apply EQ: {e}")
        logger.error(traceback.format_exc())
    return segment
289
+
290
def apply_fade(segment, fade_in_duration=500, fade_out_duration=500):
    """Fade the segment in and out; durations are logged as milliseconds.

    The segment is first passed through ``ensure_stereo`` at its own frame
    rate. On error the segment is returned in whatever state it had reached.
    """
    logger.debug(f"Applying fade: in={fade_in_duration}ms, out={fade_out_duration}ms")
    try:
        segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width)
        # Two separate assignments so a failing fade-out still keeps the fade-in.
        segment = segment.fade_in(fade_in_duration)
        segment = segment.fade_out(fade_out_duration)
        logger.debug("Fade applied")
    except Exception as e:
        logger.error(f"Failed to apply fade: {e}")
        logger.error(traceback.format_exc())
    return segment
302
+
303
+ # Red Hot Chili Peppers prompt for dynamic song structure
304
def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num):
    """Build the Red Hot Chili Peppers prompt for one generation chunk.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 90-130. Each
    instrument argument equal to ``"none"`` selects the built-in wording
    (or nothing, for the synthesizer). Chunk 1 gets the intro/verse ending;
    every other chunk gets the chorus/outro ending. ``rhythmic_steps`` is
    unused. Returns ``""`` if anything raises.
    """
    try:
        if bpm == 120:
            bpm = random.randint(90, 130)

        if drum_beat == "none":
            drum = ", standard rock drums with occasional funk grooves and dynamic fills"
        else:
            drum = f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        if bass_style == "none":
            bass = ", funky bass lines with slap technique and melodic variation"
        else:
            bass = f", {bass_style} bass"
        if guitar_style == "none":
            guitar = ", energetic guitar riffs with punk rock energy and tonal shifts"
        else:
            guitar = f", {guitar_style} guitar"

        base_prompt = (
            f"Instrumental alternative rock by Red Hot Chili Peppers{guitar}{bass}{drum}{synth}, blending funk rock and rap rock elements, "
            f"capturing the raw energy of early 90s rock with dynamic variation to avoid monotony at {bpm} BPM"
        )

        if chunk_num == 1:
            ending = ", featuring a dynamic intro and expressive verse with a mix of upbeat and introspective tones."
        else:
            ending = ", featuring a powerful chorus and energetic outro with heightened intensity and drive."
        prompt = base_prompt + ending

        logger.debug(f"Generated RHCP prompt for chunk {chunk_num}: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate RHCP prompt for chunk {chunk_num}: {e}")
        logger.error(traceback.format_exc())
        return ""
331
+
332
+ # Other prompt functions (unchanged)
333
def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Nirvana grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 100-130.
    ``"none"`` for bass, guitar or rhythm picks one of two built-in options
    at random; ``"none"`` for drums uses the standard wording and for the
    synthesizer omits it. Returns ``""`` if anything raises.
    """
    try:
        # Random draws stay in the original order: bpm, bass, guitar, rhythm.
        if bpm == 120:
            bpm = random.randint(100, 130)
        drum_kind = "standard rock" if drum_beat == "none" else drum_beat
        drum = f", {drum_kind} drums, punk energy"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        if bass_style == "none":
            bass_style = random.choice(['deep bass', 'melodic bass'])
        bass = f", {bass_style}"
        if guitar_style == "none":
            guitar_style = random.choice(['distorted guitar', 'clean guitar'])
        guitar = f", {guitar_style}"
        if rhythmic_steps == "none":
            rhythmic_steps = random.choice(['steady steps', 'dynamic shifts'])
        rhythm = f", {rhythmic_steps}"
        prompt = (
            f"Instrumental grunge by Nirvana{guitar}{bass}{drum}{synth}, raw lo-fi production, emotional rawness{rhythm} at {bpm} BPM."
        )
        logger.debug(f"Generated Nirvana prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Nirvana prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
354
+
355
def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Pearl Jam grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 100-140.
    ``"none"`` for guitar or rhythm picks one of two built-in options at
    random; ``"none"`` for drums/bass uses the built-in wording and for the
    synthesizer omits it. Returns ``""`` if anything raises.
    """
    try:
        # Random draws stay in the original order: bpm, guitar, rhythm.
        if bpm == 120:
            bpm = random.randint(100, 140)
        drum_kind = "standard rock" if drum_beat == "none" else drum_beat
        drum = f", {drum_kind} drums, driving rhythm"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass_kind = "melodic bass" if bass_style == "none" else bass_style
        bass = f", {bass_kind}, emotional tone"
        if guitar_style == "none":
            guitar_style = random.choice(['clean guitar', 'distorted guitar'])
        guitar = f", {guitar_style}, soulful leads"
        if rhythmic_steps == "none":
            rhythmic_steps = random.choice(['steady steps', 'syncopated steps'])
        rhythm = f", {rhythmic_steps}"
        prompt = (
            f"Instrumental grunge by Pearl Jam{guitar}{bass}{drum}{synth}, classic rock influences, narrative depth{rhythm} at {bpm} BPM."
        )
        logger.debug(f"Generated Pearl Jam prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Pearl Jam prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
375
+
376
def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Soundgarden grunge prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 90-140.
    ``"none"`` selects the built-in wording for drums, bass, guitar and
    rhythm, and omits the synthesizer. Returns ``""`` if anything raises.
    """
    try:
        if bpm == 120:
            bpm = random.randint(90, 140)
        drum_kind = "standard rock" if drum_beat == "none" else drum_beat
        drum = f", {drum_kind} drums, heavy rhythm"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass_kind = "deep bass" if bass_style == "none" else bass_style
        bass = f", {bass_kind}, sludgy tone"
        guitar_kind = "distorted guitar" if guitar_style == "none" else guitar_style
        guitar = f", {guitar_kind}, downtuned riffs, psychedelic vibe"
        rhythm_kind = "complex steps" if rhythmic_steps == "none" else rhythmic_steps
        rhythm = f", {rhythm_kind}"
        prompt = (
            f"Instrumental grunge with heavy metal influences by Soundgarden{guitar}{bass}{drum}{synth}, vocal-driven melody, experimental time signatures{rhythm} at {bpm} BPM."
        )
        logger.debug(f"Generated Soundgarden prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Soundgarden prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
394
+
395
def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Foo Fighters prompt.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 110-150.
    ``"none"`` for guitar or rhythm picks one of two built-in options at
    random; ``"none"`` for drums/bass uses the built-in wording and for the
    synthesizer omits it. Returns ``""`` if anything raises.

    The prompt text previously contained the mis-encoded sequence for a right
    single quote in "Grohl's"; it now carries the intended character.
    """
    try:
        bpm_range = (110, 150)
        bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
        drum = ", standard rock drums, powerful drive" if drum_beat == "none" else f", {drum_beat} drums, powerful drive"
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        bass = ", melodic bass, supportive tone" if bass_style == "none" else f", {bass_style}, supportive tone"
        chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
        guitar = f", {chosen_guitar}, anthemic quality"
        chosen_rhythm = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps
        rhythm = f", {chosen_rhythm}"
        prompt = (
            f"Instrumental alternative rock with post-grunge influences by Foo Fighters{guitar}, stadium-ready hooks{bass}{drum}{synth}, Grohl’s raw energy{rhythm} at {bpm} BPM."
        )
        logger.debug(f"Generated Foo Fighters prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Foo Fighters prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
415
+
416
def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the "classic rock" prompt, which is worded as Metallica thrash metal.

    A ``bpm`` of exactly 120 is replaced by a random tempo in 120-180.
    ``"none"`` selects the built-in wording for drums, bass, guitar and
    rhythm, and omits the synthesizer. Returns ``""`` if anything raises.
    """
    try:
        if bpm == 120:
            bpm = random.randint(120, 180)
        drum = ", double bass drums" if drum_beat == "none" else f", {drum_beat} drums"
        synth = "" if synthesizer == "none" else f", {synthesizer}"
        bass_kind = "aggressive bass" if bass_style == "none" else bass_style
        bass = f", {bass_kind}"
        guitar_kind = "distorted guitar" if guitar_style == "none" else guitar_style
        guitar = f", {guitar_kind}, blazing fast riffs"
        rhythm_kind = "complex steps" if rhythmic_steps == "none" else rhythmic_steps
        rhythm = f", {rhythm_kind}"
        prompt = (
            f"Instrumental thrash metal by Metallica{guitar}{bass}{drum}{synth}, raw intensity{rhythm} at {bpm} BPM."
        )
        logger.debug(f"Generated Metallica prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Metallica prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
434
+
435
def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Smashing Pumpkins prompt.

    ``"none"`` omits drums and bass, and falls back to "lush synths" and
    "dreamy guitar". ``rhythmic_steps`` is unused. Returns ``""`` if anything
    raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ""
        synth = f", {synthesizer}" if synthesizer != "none" else ", lush synths"
        # Fixed: the test was `== "none"`, which emitted ", none bass" and
        # silently dropped any bass style the caller actually chose.
        bass = f", {bass_style} bass" if bass_style != "none" else ""
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", dreamy guitar"
        prompt = (
            f"Instrumental alternative rock by Smashing Pumpkins{guitar}{synth}{drum}{bass} at {bpm} BPM."
        )
        logger.debug(f"Generated Smashing Pumpkins prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Smashing Pumpkins prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
450
+
451
def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Radiohead prompt.

    ``"none"`` omits drums and guitar, and falls back to "atmospheric
    synths" and "hypnotic bass". ``rhythmic_steps`` is unused. Returns ``""``
    if anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ""
        synth = f", {synthesizer}" if synthesizer != "none" else ", atmospheric synths"
        # Fixed: the test was `== "none"`, so "none" produced ", none bass"
        # and a real choice was replaced by the default.
        bass = f", {bass_style} bass" if bass_style != "none" else ", hypnotic bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
        prompt = (
            f"Instrumental experimental rock by Radiohead{synth}{bass}{drum}{guitar} at {bpm} BPM."
        )
        logger.debug(f"Generated Radiohead prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Radiohead prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
466
+
467
def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the alternative rock (Pixies) prompt.

    ``"none"`` omits drums and synthesizer, and falls back to "melodic bass"
    and "distorted guitar". ``rhythmic_steps`` is unused. Returns ``""`` if
    anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ""
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        # Fixed: the test was `== "none"`, so "none" produced ", none bass"
        # and a real choice was replaced by the default.
        bass = f", {bass_style} bass" if bass_style != "none" else ", melodic bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", distorted guitar"
        prompt = (
            f"Instrumental alternative rock by Pixies{guitar}{bass}{drum}{synth} at {bpm} BPM."
        )
        logger.debug(f"Generated Alternative Rock prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Alternative Rock prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
482
+
483
def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the post-punk (Joy Division) prompt.

    ``"none"`` omits the synthesizer and falls back to "precise drums",
    "driving bass" and "jangly guitar". ``rhythmic_steps`` is unused.
    Returns ``""`` if anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", precise drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        # Fixed: the test was `== "none"`, so "none" produced ", none bass"
        # and a real choice was replaced by the default.
        bass = f", {bass_style} bass" if bass_style != "none" else ", driving bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar"
        prompt = (
            f"Instrumental post-punk by Joy Division{guitar}{bass}{drum}{synth} at {bpm} BPM."
        )
        logger.debug(f"Generated Post-Punk prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Post-Punk prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
498
+
499
def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the indie rock (Arctic Monkeys) prompt.

    ``"none"`` omits drums and synthesizer, and falls back to "groovy bass"
    and "jangly guitar". ``rhythmic_steps`` is unused. Returns ``""`` if
    anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ""
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        # Fixed: both tests were `== "none"`, so "none" produced
        # ", none bass" / ", none guitar" and real choices were ignored.
        bass = f", {bass_style} bass" if bass_style != "none" else ", groovy bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar"
        prompt = (
            f"Instrumental indie rock by Arctic Monkeys{guitar}{bass}{drum}{synth} at {bpm} BPM."
        )
        logger.debug(f"Generated Indie Rock prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Indie Rock prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
514
+
515
def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the funk rock (Rage Against the Machine) prompt.

    ``"none"`` omits the synthesizer and falls back to "heavy drums",
    "slap bass" and "funky guitar". ``rhythmic_steps`` is unused. Returns
    ``""`` if anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", heavy drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ""
        # Fixed: both tests were `== "none"`, so "none" produced
        # ", none bass" / ", none guitar" and real choices were ignored.
        bass = f", {bass_style} bass" if bass_style != "none" else ", slap bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", funky guitar"
        prompt = (
            f"Instrumental funk rock by Rage Against the Machine{guitar}{bass}{drum}{synth} at {bpm} BPM."
        )
        logger.debug(f"Generated Funk Rock prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Funk Rock prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
530
+
531
def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Detroit techno (Juan Atkins) prompt.

    ``"none"`` omits the guitar and falls back to "four-on-the-floor drums",
    "pulsing synths" and "driving bass". ``rhythmic_steps`` is unused.
    Returns ``""`` if anything raises.
    """
    try:
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", four-on-the-floor drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ", pulsing synths"
        # Fixed: both tests were `== "none"`, so "none" produced
        # ", none bass" / ", none guitar" and real choices were ignored.
        bass = f", {bass_style} bass" if bass_style != "none" else ", driving bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
        prompt = (
            f"Instrumental Detroit techno by Juan Atkins{synth}{bass}{drum}{guitar} at {bpm} BPM."
        )
        logger.debug(f"Generated Detroit Techno prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Detroit Techno prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
546
+
547
def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the deep house (Larry Heard) prompt.

    ``"none"`` omits the guitar and falls back to "steady kick drums",
    "warm synths" and "deep bass". ``rhythmic_steps`` is unused. Returns
    ``""`` if anything raises.
    """
    try:
        # Fixed: the drum, bass and guitar tests were `== "none"`, so "none"
        # produced ", none drums" etc. and real choices were ignored.
        drum = f", {drum_beat} drums" if drum_beat != "none" else ", steady kick drums"
        synth = f", {synthesizer}" if synthesizer != "none" else ", warm synths"
        bass = f", {bass_style} bass" if bass_style != "none" else ", deep bass"
        guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
        prompt = (
            f"Instrumental deep house by Larry Heard{synth}{bass}{drum}{guitar} at {bpm} BPM."
        )
        logger.debug(f"Generated Deep House prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Failed to generate Deep House prompt: {e}")
        logger.error(traceback.format_exc())
        return ""
562
+
563
+ # Preset configurations with user-recommended settings
564
# Every preset currently carries the same sampling parameters; each name
# still gets its own dict so one preset can be edited without touching others.
_PRESET_SAMPLING = {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}
PRESETS = {
    preset_name: dict(_PRESET_SAMPLING)
    for preset_name in ("default", "rock", "techno", "grunge", "indie", "funk_rock")
}
572
+
573
+ # Function to get the latest log file
574
def get_latest_log():
    """Return the text of the most recently modified ``musicgen_log_*.log``.

    Looks in the module-level ``log_dir``. Returns ``"No log files found."``
    when there is none, or an ``"Error reading log file: ..."`` message if
    reading fails.
    """
    try:
        candidates = list(Path(log_dir).glob("musicgen_log_*.log"))
        candidates.sort(key=os.path.getmtime, reverse=True)
        if not candidates:
            logger.warning("No log files found")
            return "No log files found."
        newest = candidates[0]
        with open(newest, "r") as f:
            content = f.read()
        logger.info(f"Retrieved latest log file: {newest}")
        return content
    except Exception as e:
        logger.error(f"Failed to read log file: {e}")
        logger.error(traceback.format_exc())
        return f"Error reading log file: {e}"
588
+
589
+ # Bitrate selection functions with visual feedback
590
def set_bitrate_128():
    """Log the selection and return the bitrate token ``"128k"``."""
    selected = "128k"
    logger.info("Bitrate set to 128 kbps")
    return selected


def set_bitrate_192():
    """Log the selection and return the bitrate token ``"192k"``."""
    selected = "192k"
    logger.info("Bitrate set to 192 kbps")
    return selected


def set_bitrate_320():
    """Log the selection and return the bitrate token ``"320k"``."""
    selected = "320k"
    logger.info("Bitrate set to 320 kbps")
    return selected
601
+
602
+ # Sampling rate selection functions with visual feedback
603
def set_sample_rate_22050():
    """Log the selection and return the sample-rate string ``"22050"``."""
    selected = "22050"
    logger.info("Output sampling rate set to 22.05 kHz")
    return selected


def set_sample_rate_44100():
    """Log the selection and return the sample-rate string ``"44100"``."""
    selected = "44100"
    logger.info("Output sampling rate set to 44.1 kHz")
    return selected


def set_sample_rate_48000():
    """Log the selection and return the sample-rate string ``"48000"``."""
    selected = "48000"
    logger.info("Output sampling rate set to 48 kHz")
    return selected
614
+
615
+ # Bit depth selection functions with visual feedback
616
def set_bit_depth_16():
    """Log the selection and return the bit-depth string ``"16"``."""
    selected = "16"
    logger.info("Bit depth set to 16-bit")
    return selected


def set_bit_depth_24():
    """Log the selection and return the bit-depth string ``"24"``."""
    selected = "24"
    logger.info("Bit depth set to 24-bit")
    return selected
623
+
624
+ # Wrapper for generate_music with post-generation cleanup
625
def generate_music_wrapper(*args):
    """Call ``generate_music`` with ``args`` and always clean GPU memory after.

    Returns whatever ``generate_music`` returns; ``clean_memory`` runs in a
    ``finally`` clause, so it also runs when generation raises.
    """
    try:
        return generate_music(*args)
    finally:
        clean_memory()
631
+
632
+ # Optimized generation function with chunk-based prompt variation
633
def _remove_temp_file(path):
    """Best-effort removal of a temporary file; a missing file is not an error."""
    if not path:
        return
    try:
        os.remove(path)
        logger.debug(f"Deleted temporary file {path}")
    except FileNotFoundError:
        pass
    except OSError:
        logger.warning(f"Failed to delete temporary file {path}")


def generate_music(
    instrumental_prompt: str,
    cfg_scale: float,
    top_k: int,
    top_p: float,
    temperature: float,
    total_duration: int,
    bpm: int,
    drum_beat: str,
    synthesizer: str,
    rhythmic_steps: str,
    bass_style: str,
    guitar_style: str,
    target_volume: float,
    preset: str,
    max_steps: str,
    vram_status: str,
    bitrate: str,
    output_sample_rate: str,
    bit_depth: str,
):
    """Generate an instrumental track chunk by chunk, stitch it, and export an MP3.

    The track is produced in chunks of at most ``min(max_steps / 50, 30)``
    seconds. The first chunk is generated from the text prompt; later chunks
    continue from the last ``overlap_duration`` seconds of the previous chunk.
    Chunks are post-processed, crossfaded together, and exported as MP3.

    Args:
        instrumental_prompt: Text prompt; must be non-empty.
        cfg_scale, top_k, top_p, temperature: Sampling parameters. They are
            replaced by the values in ``PRESETS`` when ``preset`` is not
            ``"default"``.
        total_duration: Requested length in seconds (clamped to 30..120).
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style:
            Only used to rebuild the per-chunk prompt when the prompt mentions
            "Red Hot Chili Peppers".
        target_volume: Target RMS level passed to ``rms_normalize``.
        preset: Key into ``PRESETS``.
        max_steps: Generation steps per chunk (number or numeric string).
        vram_status: Current VRAM status text, returned unchanged on early exit.
        bitrate: MP3 bitrate string such as ``"128k"``.
        output_sample_rate: Output sampling rate (number or numeric string).
        bit_depth: ``"16"`` or ``"24"``.

    Returns:
        A ``(mp3_path_or_None, status_message, vram_status)`` tuple. Errors are
        reported through the status message; the function does not raise.
    """
    global musicgen_model
    if not instrumental_prompt or not instrumental_prompt.strip():
        logger.warning("Empty instrumental prompt provided")
        return None, "⚠️ Please enter a valid instrumental prompt!", vram_status
    try:
        logger.info("Starting music generation...")
        start_time = time.time()
        clean_memory()
        try:
            max_steps_int = int(max_steps)
            # A non-positive step count would yield zero-length chunks and never finish.
            if max_steps_int <= 0:
                raise ValueError("max_steps must be positive")
        except (TypeError, ValueError):
            logger.error(f"Invalid max_steps value: {max_steps}")
            return None, "❌ Invalid max_steps value; must be a number (1000, 1200, 1300, or 1500)", vram_status
        try:
            output_sample_rate_int = int(output_sample_rate)
        except (TypeError, ValueError):
            logger.error(f"Invalid output_sample_rate value: {output_sample_rate}")
            return None, "❌ Invalid output sampling rate; must be a number (22050, 32000, 44100, or 48000)", vram_status
        try:
            bit_depth_int = int(bit_depth)
            if bit_depth_int not in (16, 24):
                raise ValueError("unsupported bit depth")
        except (TypeError, ValueError):
            logger.error(f"Invalid bit_depth value: {bit_depth}")
            return None, "❌ Invalid bit depth; must be 16 or 24", vram_status
        sample_width = 3 if bit_depth_int == 24 else 2
        max_duration = min(max_steps_int / 50, 30)
        total_duration = min(max(total_duration, 30), 120)
        processing_sample_rate = 48000  # Updated to user-recommended value
        audio_segments = []
        overlap_duration = 0.2
        remaining_duration = total_duration

        if preset != "default":
            preset_params = PRESETS.get(preset, PRESETS["default"])
            cfg_scale = preset_params["cfg_scale"]
            top_k = preset_params["top_k"]
            top_p = preset_params["top_p"]
            temperature = preset_params["temperature"]
            logger.info(f"Applied preset {preset}: cfg_scale={cfg_scale}, top_k={top_k}, top_p={top_p}, temperature={temperature}")
        # UI sliders may deliver floats; top-k must be an integer count.
        top_k = int(top_k)

        if not check_disk_space():
            logger.error("Insufficient disk space")
            return None, "⚠️ Insufficient disk space. Free up at least 1 GB.", vram_status

        seed = random.randint(0, 10000)
        logger.info(f"Generating audio for {total_duration}s with seed={seed}, max_steps={max_steps_int}, output_sample_rate={output_sample_rate_int} Hz, bit_depth={bit_depth_int}-bit")
        vram_status = f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"

        chunk_num = 0
        while remaining_duration > 0:
            current_duration = min(max_duration, remaining_duration)
            # Account for this chunk up front: the "skip chunk" paths below use
            # `continue`, and without this a persistent failure would loop forever.
            remaining_duration -= current_duration
            chunk_num += 1
            logger.info(f"Generating chunk {chunk_num} ({current_duration}s, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB)")

            # Generate chunk-specific prompt for Red Hot Chili Peppers
            if "Red Hot Chili Peppers" in instrumental_prompt:
                chunk_prompt = set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num)
            else:
                # For other prompts, use the base prompt without variation (as a fallback)
                chunk_prompt = instrumental_prompt

            musicgen_model.set_generation_params(
                duration=current_duration,
                use_sampling=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                cfg_coef=cfg_scale
            )

            try:
                with torch.no_grad():
                    with autocast(dtype=torch.float16):
                        torch.manual_seed(seed)
                        np.random.seed(seed)
                        torch.cuda.manual_seed_all(seed)
                        clean_memory()
                        if not audio_segments:
                            logger.debug("Generating first chunk")
                            audio_segment = musicgen_model.generate([chunk_prompt], progress=True)[0].cpu()
                        else:
                            logger.debug("Generating continuation chunk")
                            prev_segment = audio_segments[-1]
                            prev_segment = apply_noise_gate(prev_segment, threshold_db=-80, sample_rate=processing_sample_rate)
                            prev_segment = balance_stereo(prev_segment, noise_threshold=-40, sample_rate=processing_sample_rate)
                            temp_wav_path = f"temp_prev_{int(time.time()*1000)}.wav"
                            try:
                                logger.debug(f"Exporting previous segment to {temp_wav_path}")
                                prev_segment.export(temp_wav_path, format="wav")
                                prev_audio, prev_sr = torchaudio.load(temp_wav_path)
                                if prev_sr != processing_sample_rate:
                                    logger.debug(f"Resampling from {prev_sr} to {processing_sample_rate}")
                                    prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, processing_sample_rate, lowpass_filter_width=64)
                                if prev_audio.shape[0] != 2:
                                    logger.debug(f"Converting to stereo: {prev_audio.shape[0]} channels detected")
                                    prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]]
                                prev_audio = prev_audio.to(device)
                                # Only the tail of the previous chunk is used as the continuation prompt.
                                audio_segment = musicgen_model.generate_continuation(
                                    prompt=prev_audio[:, -int(processing_sample_rate * overlap_duration):],
                                    prompt_sample_rate=processing_sample_rate,
                                    descriptions=[chunk_prompt],
                                    progress=True
                                )[0].cpu()
                                del prev_audio
                            finally:
                                _remove_temp_file(temp_wav_path)
                        clean_memory()
            except Exception as e:
                logger.error(f"Error in chunk {chunk_num} generation: {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to generate chunk {chunk_num}: {e}", vram_status

            logger.debug(f"Generated audio segment shape: {audio_segment.shape}, dtype: {audio_segment.dtype}")
            try:
                # Ensure the model's output is resampled to processing_sample_rate
                if audio_segment.shape[0] != 2:
                    logger.debug(f"Converting to stereo: {audio_segment.shape[0]} channels detected")
                    audio_segment = audio_segment.repeat(2, 1)[:, :audio_segment.shape[1]]
                # Convert to float32 before resampling to avoid "slow_conv2d_cpu" error
                audio_segment = audio_segment.to(dtype=torch.float32)
                audio_segment = torchaudio.functional.resample(audio_segment, 32000, processing_sample_rate, lowpass_filter_width=64)
                audio_np = audio_segment.numpy()
                if audio_np.ndim == 1:
                    logger.debug("Converting mono to stereo on CPU")
                    audio_np = np.stack([audio_np, audio_np], axis=0)
                if audio_np.shape[0] != 2:
                    logger.error(f"Expected stereo audio with shape (2, samples), got shape {audio_np.shape}")
                    return None, f"❌ Invalid audio shape for chunk {chunk_num}: {audio_np.shape}", vram_status
                audio_segment = torch.from_numpy(audio_np).to(dtype=torch.float16)
                logger.debug(f"Converted audio segment to float16, shape: {audio_segment.shape}")
            except Exception as e:
                logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}")
                logger.error(traceback.format_exc())
                return None, f"❌ Failed to process audio for chunk {chunk_num}: {e}", vram_status

            temp_wav_path = f"temp_audio_{int(time.time()*1000)}.wav"
            logger.debug(f"Saving audio segment to {temp_wav_path}, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
            try:
                audio_segment_save = audio_segment.to(dtype=torch.float32)
                torchaudio.save(temp_wav_path, audio_segment_save, processing_sample_rate, bits_per_sample=bit_depth_int)
                del audio_segment_save
            except Exception as e:
                logger.error(f"Failed to save audio segment for chunk {chunk_num}: {e}")
                logger.error(traceback.format_exc())
                logger.warning(f"Skipping chunk {chunk_num} due to save error")
                _remove_temp_file(temp_wav_path)  # drop any partially written file
                del audio_segment
                clean_memory()
                continue

            clean_memory()
            try:
                segment = AudioSegment.from_wav(temp_wav_path)
            except Exception as e:
                logger.error(f"Failed to load WAV file for chunk {chunk_num}: {e}")
                logger.error(traceback.format_exc())
                logger.warning(f"Skipping chunk {chunk_num} due to WAV load error")
                del audio_segment
                clean_memory()
                continue
            finally:
                _remove_temp_file(temp_wav_path)

            try:
                segment = ensure_stereo(segment, processing_sample_rate, sample_width)
                segment = segment - 15
                if segment.frame_rate != processing_sample_rate:
                    logger.debug(f"Setting segment sample rate to {processing_sample_rate}")
                    segment = segment.set_frame_rate(processing_sample_rate)
                # Apply noise gate immediately after loading to catch high-pitched tones early
                segment = apply_noise_gate(segment, threshold_db=-80, sample_rate=processing_sample_rate)
                segment = balance_stereo(segment, noise_threshold=-40, sample_rate=processing_sample_rate)
                segment = rms_normalize(segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
                segment = apply_eq(segment, sample_rate=processing_sample_rate)
                audio_segments.append(segment)
            except Exception as e:
                logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}")
                logger.error(traceback.format_exc())
                logger.warning(f"Skipping chunk {chunk_num} due to processing error")
                del audio_segment
                clean_memory()
                continue

            del audio_segment
            del audio_np
            clean_memory()
            vram_status = f"VRAM after chunk {chunk_num}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            time.sleep(0.1)

        if not audio_segments:
            logger.error("No audio segments generated")
            return None, "❌ No audio segments generated due to errors", vram_status

        logger.info("Combining audio chunks...")
        try:
            final_segment = audio_segments[0][:min(max_duration, total_duration) * 1000]
            final_segment = ensure_stereo(final_segment, processing_sample_rate, sample_width)
            overlap_ms = int(overlap_duration * 1000)

            for i in range(1, len(audio_segments)):
                current_segment = audio_segments[i]
                current_segment = current_segment[:min(max_duration, total_duration - (i * max_duration)) * 1000]
                current_segment = ensure_stereo(current_segment, processing_sample_rate, sample_width)

                if overlap_ms > 0 and len(current_segment) > overlap_ms:
                    logger.debug(f"Applying crossfade between chunks {i} and {i+1}")
                    prev_overlap = final_segment[-overlap_ms:]
                    curr_overlap = current_segment[:overlap_ms]
                    # All three paths are defined before the try so the cleanup in
                    # `finally` can never hit an unbound name.
                    stamp = int(time.time()*1000)
                    prev_wav_path = f"temp_prev_overlap_{stamp}.wav"
                    curr_wav_path = f"temp_curr_overlap_{stamp}.wav"
                    temp_crossfade_path = f"temp_crossfade_{stamp}.wav"
                    try:
                        prev_overlap.export(prev_wav_path, format="wav")
                        curr_overlap.export(curr_wav_path, format="wav")
                        clean_memory()
                        prev_audio, _ = torchaudio.load(prev_wav_path)
                        curr_audio, _ = torchaudio.load(curr_wav_path)
                        num_samples = min(prev_audio.shape[1], curr_audio.shape[1])
                        num_samples = num_samples - (num_samples % 2)
                        if num_samples <= 0:
                            logger.warning(f"Skipping crossfade for chunk {i+1} due to insufficient samples")
                            final_segment += current_segment
                            continue
                        prev_samples = prev_audio[:, :num_samples]
                        curr_samples = curr_audio[:, :num_samples]
                        # Rising half of a symmetric Hann window of length 2N-1: goes
                        # 0 -> 1 over N samples, and its flip is the complementary
                        # 1 -> 0 ramp, so fade_in + fade_out == 1 at every sample.
                        # (A full-length Hann window is a bell that is zero at both
                        # ends, which would mute both edges of the overlap.)
                        fade_in = torch.hann_window(2 * num_samples - 1, periodic=False)[:num_samples]
                        fade_out = fade_in.flip(0)
                        blended_samples = prev_samples * fade_out + curr_samples * fade_in
                        # Saved as float, the same way the per-chunk WAVs are written
                        # above; clamp guards against out-of-range samples.
                        blended_samples = blended_samples.clamp(-1.0, 1.0).to(torch.float32)
                        torchaudio.save(temp_crossfade_path, blended_samples, processing_sample_rate, bits_per_sample=bit_depth_int)
                        blended_segment = AudioSegment.from_wav(temp_crossfade_path)
                        blended_segment = ensure_stereo(blended_segment, processing_sample_rate, sample_width)
                        blended_segment = rms_normalize(blended_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
                        final_segment = final_segment[:-overlap_ms] + blended_segment + current_segment[overlap_ms:]
                    finally:
                        for temp_path in (prev_wav_path, curr_wav_path, temp_crossfade_path):
                            _remove_temp_file(temp_path)
                else:
                    logger.debug(f"Concatenating chunk {i+1} without crossfade")
                    final_segment += current_segment

            final_segment = final_segment[:total_duration * 1000]
            logger.info("Post-processing final track...")
            final_segment = apply_noise_gate(final_segment, threshold_db=-80, sample_rate=processing_sample_rate)
            final_segment = balance_stereo(final_segment, noise_threshold=-40, sample_rate=processing_sample_rate)
            final_segment = rms_normalize(final_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
            final_segment = apply_eq(final_segment, sample_rate=processing_sample_rate)
            final_segment = apply_fade(final_segment)
            final_segment = final_segment - 10
            final_segment = final_segment.set_frame_rate(output_sample_rate_int)

            mp3_path = f"output_adjusted_volume_{int(time.time())}.mp3"
            logger.info("⚠️ WARNING: Audio is set to safe levels (~ -23 dBFS RMS, -3 dBFS peak). Start playback at LOW volume (10-20%) and adjust gradually.")
            logger.info("VERIFY: Open the file in Audacity to check for high-pitched tones and quality. RMS should be ~ -23 dBFS, peaks ≀ -3 dBFS. Report any issues.")
            try:
                clean_memory()
                logger.debug(f"Exporting final audio to {mp3_path} with bitrate {bitrate}, sample rate {output_sample_rate_int} Hz, bit depth {bit_depth_int}-bit")
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate=bitrate,
                    tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
                )
                logger.info(f"Final audio saved to {mp3_path}")
            except Exception as e:
                logger.error(f"Error exporting MP3 with bitrate {bitrate}: {e}")
                logger.error(traceback.format_exc())
                # Retry once at a conservative bitrate before giving up.
                fallback_path = f"fallback_output_{int(time.time())}.mp3"
                try:
                    final_segment.export(fallback_path, format="mp3", bitrate="128k")
                    logger.info(f"Final audio saved to fallback: {fallback_path} with 128 kbps")
                    mp3_path = fallback_path
                except Exception as fallback_e:
                    logger.error(f"Failed to save fallback MP3: {fallback_e}")
                    return None, f"❌ Failed to export audio: {fallback_e}", vram_status

            vram_status = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            logger.info(f"Generation completed in {time.time() - start_time:.2f} seconds")
            return mp3_path, "βœ… Done! Generated track with adjusted volume levels. Check for quality in Audacity.", vram_status
        except Exception as e:
            logger.error(f"Failed to combine audio chunks: {e}")
            logger.error(traceback.format_exc())
            return None, f"❌ Failed to combine audio: {e}", vram_status
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        logger.error(traceback.format_exc())
        return None, f"❌ Generation failed: {e}", vram_status
    finally:
        clean_memory()
945
+
946
+ # Clear inputs function
947
def clear_inputs():
    """Return the default value for every input, in the order wired to the Clear button.

    Order: prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
    drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
    target_volume, preset, max_steps, bitrate, sample rate, bit depth.
    """
    logger.info("Clearing input fields")
    defaults = (
        "", 5.8, 18, 0.88, 0.15, 30, 120,
        "none", "none", "none", "none", "none",
        -23.0, "default", 1300,
        "128k", "44100", "16",
    )
    return defaults
950
+
951
# Custom CSS passed to gr.Blocks(css=...): dark, high-contrast theme.
# The `.active` rules draw a green border on a button that carries the
# `active` class, which is toggled client-side by the button click handlers.
css = """
body {
    background: #121212;
    color: #E6E6E6;
    font-family: 'Arial', sans-serif;
}
.header-container {
    text-align: center;
    padding: 15px 20px;
    background: #1E1E1E;
    border-bottom: 2px solid #00C853;
}
#ghost-logo {
    font-size: 48px;
    color: #00C853;
}
h1 {
    color: #FFD600;
    font-size: 28px;
    font-weight: bold;
}
h3 {
    color: #FFD600;
    font-size: 20px;
    font-weight: bold;
}
p {
    color: #B0BEC5;
    font-size: 14px;
}
.input-container, .settings-container, .output-container, .logs-container {
    max-width: 1200px;
    margin: 20px auto;
    padding: 20px;
    background: #212121;
    border: 1px solid #424242;
    border-radius: 8px;
}
.textbox {
    background: #2C2C2C;
    border: 1px solid #B0BEC5;
    color: #E6E6E6;
    font-size: 16px;
}
.genre-buttons, .bitrate-buttons, .sample-rate-buttons, .bit-depth-buttons {
    display: flex;
    justify-content: center;
    flex-wrap: wrap;
    gap: 10px;
}
.genre-btn, .bitrate-btn, .sample-rate-btn, .bit-depth-btn, button {
    background: #0288D1;
    border: 2px solid transparent;
    color: #FFFFFF;
    padding: 10px 20px;
    border-radius: 5px;
    font-size: 16px;
    transition: all 0.3s ease;
}
button:hover {
    background: #03A9F4;
    cursor: pointer;
}
button:active, .genre-btn.active, .bitrate-btn.active, .sample-rate-btn.active, .bit-depth-btn.active {
    border: 2px solid #00C853 !important;
    background: #01579B;
    color: #FFFFFF;
}
.gradio-container {
    padding: 20px;
}
.group-container {
    margin-bottom: 20px;
    padding: 15px;
    border: 1px solid #424242;
    border-radius: 8px;
}
.slider-label, .dropdown-label {
    color: #FFD600;
    font-size: 16px;
    font-weight: bold;
}
.slider, .dropdown {
    background: #2C2C2C;
    color: #E6E6E6;
}
.output-container label, .logs-container label {
    color: #FFD600;
    font-size: 16px;
    font-weight: bold;
}
"""
1044
+
1045
+ # Build Gradio interface with updated visuals and default preset
1046
logger.info("Building Gradio interface...")
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
        <div class="header-container">
            <div id="ghost-logo">πŸ‘»</div>
            <h1>GhostAI Music Generator 🎹</h1>
            <p>Create Instrumental Tracks with Ease</p>
        </div>
    """)

    with gr.Column(elem_classes="input-container"):
        gr.Markdown("### 🎸 Prompt Settings")
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt ✍️",
            placeholder="Click a genre button or type your own instrumental prompt",
            lines=4,
            elem_classes="textbox"
        )
        # Every selectable button gets an elem_id equal to its variable name so the
        # client-side highlight script (update_button_styles) can find it.
        with gr.Row(elem_classes="genre-buttons"):
            rhcp_btn = gr.Button("Red Hot Chili Peppers 🌢️", elem_classes="genre-btn", elem_id="rhcp_btn")
            nirvana_btn = gr.Button("Nirvana Grunge 🎸", elem_classes="genre-btn", elem_id="nirvana_btn")
            pearl_jam_btn = gr.Button("Pearl Jam Grunge πŸ¦ͺ", elem_classes="genre-btn", elem_id="pearl_jam_btn")
            soundgarden_btn = gr.Button("Soundgarden Grunge πŸŒ‘", elem_classes="genre-btn", elem_id="soundgarden_btn")
            foo_fighters_btn = gr.Button("Foo Fighters 🀘", elem_classes="genre-btn", elem_id="foo_fighters_btn")
            smashing_pumpkins_btn = gr.Button("Smashing Pumpkins πŸŽƒ", elem_classes="genre-btn", elem_id="smashing_pumpkins_btn")
            radiohead_btn = gr.Button("Radiohead 🧠", elem_classes="genre-btn", elem_id="radiohead_btn")
            classic_rock_btn = gr.Button("Metallica Heavy Metal 🎸", elem_classes="genre-btn", elem_id="classic_rock_btn")
            alternative_rock_btn = gr.Button("Alternative Rock 🎡", elem_classes="genre-btn", elem_id="alternative_rock_btn")
            post_punk_btn = gr.Button("Post-Punk πŸ–€", elem_classes="genre-btn", elem_id="post_punk_btn")
            indie_rock_btn = gr.Button("Indie Rock 🎀", elem_classes="genre-btn", elem_id="indie_rock_btn")
            funk_rock_btn = gr.Button("Funk Rock πŸ•Ί", elem_classes="genre-btn", elem_id="funk_rock_btn")
            detroit_techno_btn = gr.Button("Detroit Techno πŸŽ›οΈ", elem_classes="genre-btn", elem_id="detroit_techno_btn")
            deep_house_btn = gr.Button("Deep House 🏠", elem_classes="genre-btn", elem_id="deep_house_btn")

    with gr.Column(elem_classes="settings-container"):
        gr.Markdown("### βš™οΈ API Settings")
        with gr.Group(elem_classes="group-container"):
            cfg_scale = gr.Slider(
                label="CFG Scale 🎯",
                minimum=1.0,
                maximum=10.0,
                value=5.8,
                step=0.1,
                info="Controls how closely the music follows the prompt."
            )
            top_k = gr.Slider(
                label="Top-K Sampling πŸ”’",
                minimum=10,
                maximum=500,
                value=18,
                step=10,
                info="Limits sampling to the top k most likely tokens."
            )
            top_p = gr.Slider(
                label="Top-P Sampling 🎰",
                minimum=0.0,
                maximum=1.0,
                value=0.88,
                step=0.05,
                info="Keeps tokens with cumulative probability above p."
            )
            temperature = gr.Slider(
                label="Temperature πŸ”₯",
                minimum=0.1,
                maximum=2.0,
                value=0.15,
                step=0.1,
                info="Controls randomness; lower values reduce noise."
            )
            total_duration = gr.Dropdown(
                label="Song Length ⏳ (seconds)",
                choices=[30, 60, 90, 120],
                value=30,
                info="Select the total duration of the track."
            )
            bpm = gr.Slider(
                label="Tempo 🎡 (BPM)",
                minimum=60,
                maximum=180,
                value=120,
                step=1,
                info="Beats per minute to set the track's tempo."
            )
            drum_beat = gr.Dropdown(
                label="Drum Beat πŸ₯",
                choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                value="none",
                info="Select a drum beat style to influence the rhythm."
            )
            synthesizer = gr.Dropdown(
                label="Synthesizer 🎹",
                choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                value="none",
                info="Select a synthesizer style for electronic accents."
            )
            rhythmic_steps = gr.Dropdown(
                label="Rhythmic Steps πŸ‘£",
                choices=["none", "syncopated steps", "steady steps", "complex steps"],
                value="none",
                info="Select a rhythmic step style to enhance the beat."
            )
            bass_style = gr.Dropdown(
                label="Bass Style 🎸",
                choices=["none", "slap bass", "deep bass", "melodic bass"],
                value="none",
                info="Select a bass style to shape the low end."
            )
            guitar_style = gr.Dropdown(
                label="Guitar Style 🎸",
                choices=["none", "distorted", "clean", "jangle"],
                value="none",
                info="Select a guitar style to define the riffs."
            )
            target_volume = gr.Slider(
                label="Target Volume 🎚️ (dBFS RMS)",
                minimum=-30.0,
                maximum=-20.0,
                value=-23.0,
                step=1.0,
                info="Adjust output loudness (-23 dBFS is standard, -20 dBFS is louder, -30 dBFS is quieter)."
            )
            preset = gr.Dropdown(
                label="Preset Configuration πŸŽ›οΈ",
                choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"],
                value="default",
                info="Select a preset optimized for specific genres."
            )
            max_steps = gr.Dropdown(
                label="Max Steps per Chunk πŸ“",
                choices=[1000, 1200, 1300, 1500],
                value=1300,
                info="Number of generation steps per chunk (1300=~26s, extended to 30s)."
            )
            # Export settings live in hidden State components set by the buttons below.
            bitrate_state = gr.State(value="128k")
            sample_rate_state = gr.State(value="44100")
            bit_depth_state = gr.State(value="16")
            with gr.Row(elem_classes="bitrate-buttons"):
                bitrate_128_btn = gr.Button("Set Bitrate to 128 kbps", elem_classes="bitrate-btn", elem_id="bitrate_128_btn")
                bitrate_192_btn = gr.Button("Set Bitrate to 192 kbps", elem_classes="bitrate-btn", elem_id="bitrate_192_btn")
                bitrate_320_btn = gr.Button("Set Bitrate to 320 kbps", elem_classes="bitrate-btn", elem_id="bitrate_320_btn")
            with gr.Row(elem_classes="sample-rate-buttons"):
                sample_rate_22050_btn = gr.Button("Set Sampling Rate to 22.05 kHz", elem_classes="sample-rate-btn", elem_id="sample_rate_22050_btn")
                sample_rate_44100_btn = gr.Button("Set Sampling Rate to 44.1 kHz", elem_classes="sample-rate-btn", elem_id="sample_rate_44100_btn")
                sample_rate_48000_btn = gr.Button("Set Sampling Rate to 48 kHz", elem_classes="sample-rate-btn", elem_id="sample_rate_48000_btn")
            with gr.Row(elem_classes="bit-depth-buttons"):
                bit_depth_16_btn = gr.Button("Set Bit Depth to 16-bit", elem_classes="bit-depth-btn", elem_id="bit_depth_16_btn")
                bit_depth_24_btn = gr.Button("Set Bit Depth to 24-bit", elem_classes="bit-depth-btn", elem_id="bit_depth_24_btn")

    with gr.Row(elem_classes="action-buttons"):
        gen_btn = gr.Button("Generate Music πŸš€")
        clr_btn = gr.Button("Clear Inputs 🧹")

    with gr.Column(elem_classes="output-container"):
        gr.Markdown("### 🎧 Output")
        out_audio = gr.Audio(label="Generated Instrumental Track 🎡", type="filepath")
        status = gr.Textbox(label="Status πŸ“’", interactive=False)
        vram_status = gr.Textbox(label="VRAM Usage πŸ“Š", interactive=False, value="")

    with gr.Column(elem_classes="logs-container"):
        gr.Markdown("### πŸ“œ Logs")
        log_output = gr.Textbox(label="Last Log File Contents", lines=20, interactive=False)
        log_btn = gr.Button("View Last Log πŸ“‹")

    def update_button_styles(selected_button):
        """Return a JavaScript function that highlights the button with the given id.

        The event listeners' ``js`` argument expects JavaScript function source,
        not an HTML ``<script>`` tag. The returned function removes the
        ``active`` class from the other buttons in the same group (genre,
        bitrate, sample rate or bit depth) and adds it to the selected one.
        If the element is missing it does nothing.
        """
        return (
            "() => {"
            f" const selected = document.getElementById('{selected_button}');"
            " if (!selected) { return; }"
            " const groups = ['genre-btn', 'bitrate-btn', 'sample-rate-btn', 'bit-depth-btn'];"
            " const group = groups.find(cls => selected.classList.contains(cls));"
            # Fall back to clearing every group if the class is not on the element itself.
            " const selector = group ? '.' + group : groups.map(cls => '.' + cls).join(', ');"
            " document.querySelectorAll(selector).forEach(btn => btn.classList.remove('active'));"
            " selected.classList.add('active');"
            " }"
        )

    # Inputs shared by every genre prompt builder.
    prompt_inputs = [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]

    # The RHCP builder takes an extra chunk-number argument (1 for the UI preview).
    rhcp_btn.click(
        set_red_hot_chili_peppers_prompt,
        inputs=prompt_inputs + [gr.State(value=1)],
        outputs=instrumental_prompt
    ).then(None, None, None, js=update_button_styles("rhcp_btn"))

    # (button, elem_id, prompt builder) for the remaining genre buttons.
    genre_wiring = [
        (nirvana_btn, "nirvana_btn", set_nirvana_grunge_prompt),
        (pearl_jam_btn, "pearl_jam_btn", set_pearl_jam_grunge_prompt),
        (soundgarden_btn, "soundgarden_btn", set_soundgarden_grunge_prompt),
        (foo_fighters_btn, "foo_fighters_btn", set_foo_fighters_prompt),
        (smashing_pumpkins_btn, "smashing_pumpkins_btn", set_smashing_pumpkins_prompt),
        (radiohead_btn, "radiohead_btn", set_radiohead_prompt),
        (classic_rock_btn, "classic_rock_btn", set_classic_rock_prompt),
        (alternative_rock_btn, "alternative_rock_btn", set_alternative_rock_prompt),
        (post_punk_btn, "post_punk_btn", set_post_punk_prompt),
        (indie_rock_btn, "indie_rock_btn", set_indie_rock_prompt),
        (funk_rock_btn, "funk_rock_btn", set_funk_rock_prompt),
        (detroit_techno_btn, "detroit_techno_btn", set_detroit_techno_prompt),
        (deep_house_btn, "deep_house_btn", set_deep_house_prompt),
    ]
    for genre_button, genre_button_id, prompt_fn in genre_wiring:
        genre_button.click(
            prompt_fn,
            inputs=prompt_inputs,
            outputs=instrumental_prompt
        ).then(None, None, None, js=update_button_styles(genre_button_id))

    # (button, elem_id, setter, target State) for the export-setting buttons.
    setting_wiring = [
        (bitrate_128_btn, "bitrate_128_btn", set_bitrate_128, bitrate_state),
        (bitrate_192_btn, "bitrate_192_btn", set_bitrate_192, bitrate_state),
        (bitrate_320_btn, "bitrate_320_btn", set_bitrate_320, bitrate_state),
        (sample_rate_22050_btn, "sample_rate_22050_btn", set_sample_rate_22050, sample_rate_state),
        (sample_rate_44100_btn, "sample_rate_44100_btn", set_sample_rate_44100, sample_rate_state),
        (sample_rate_48000_btn, "sample_rate_48000_btn", set_sample_rate_48000, sample_rate_state),
        (bit_depth_16_btn, "bit_depth_16_btn", set_bit_depth_16, bit_depth_state),
        (bit_depth_24_btn, "bit_depth_24_btn", set_bit_depth_24, bit_depth_state),
    ]
    for setting_button, setting_button_id, setter_fn, target_state in setting_wiring:
        setting_button.click(
            setter_fn,
            inputs=None,
            outputs=target_state
        ).then(None, None, None, js=update_button_styles(setting_button_id))

    gen_btn.click(
        generate_music_wrapper,
        inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, vram_status, bitrate_state, sample_rate_state, bit_depth_state],
        outputs=[out_audio, status, vram_status]
    )
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state]
    )
    log_btn.click(
        get_latest_log,
        inputs=None,
        outputs=log_output
    )
1266
+
1267
+ # Launch locally without OpenAPI/docs
1268
logger.info("Launching Gradio UI at http://localhost:9999...")
# NOTE(review): server_name="0.0.0.0" binds all interfaces and share=True asks
# Gradio for a public share link, so this is not a local-only launch — confirm
# that exposure is intended.
launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 9999,
    "share": True,
    "inbrowser": False,
    "show_error": True,
}
try:
    app = demo.launch(**launch_options)
except Exception as e:
    # Startup failure is fatal: log the traceback and exit non-zero.
    logger.error(f"Failed to launch Gradio UI: {e}")
    logger.error(traceback.format_exc())
    sys.exit(1)