|  | import os | 
					
						
						|  | import torch | 
					
						
						|  | import torchaudio | 
					
						
						|  | import time | 
					
						
						|  | import sys | 
					
						
						|  | import numpy as np | 
					
						
						|  | import gc | 
					
						
						|  | import gradio as gr | 
					
						
						|  | from pydub import AudioSegment | 
					
						
						|  | from audiocraft.models import MusicGen | 
					
						
						|  | from torch.cuda.amp import autocast | 
					
						
						|  | import warnings | 
					
						
						|  | import random | 
					
						
						|  | import traceback | 
					
						
						|  | import logging | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | import mmap | 
					
						
						|  | import subprocess | 
					
						
						|  | import re | 
					
						
						|  | import gradio_client.utils | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | original_get_type = gradio_client.utils.get_type | 
					
						
						|  | def patched_get_type(schema): | 
					
						
						|  | if isinstance(schema, bool): | 
					
						
						|  | return "boolean" | 
					
						
						|  | return original_get_type(schema) | 
					
						
						|  | gradio_client.utils.get_type = patched_get_type | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | warnings.filterwarnings("ignore") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | torch.backends.cudnn.benchmark = False | 
					
						
						|  | torch.backends.cudnn.deterministic = True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | log_dir = "logs" | 
					
						
						|  | os.makedirs(log_dir, exist_ok=True) | 
					
						
						|  | log_file = os.path.join(log_dir, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") | 
					
						
						|  | logging.basicConfig( | 
					
						
						|  | level=logging.DEBUG, | 
					
						
						|  | format="%(asctime)s [%(levelname)s] %(message)s", | 
					
						
						|  | handlers=[ | 
					
						
						|  | logging.FileHandler(log_file), | 
					
						
						|  | logging.StreamHandler(sys.stdout) | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | device = "cuda" if torch.cuda.is_available() else "cpu" | 
					
						
						|  | if device != "cuda": | 
					
						
						|  | logger.error("CUDA is required for GPU rendering. CPU rendering is disabled.") | 
					
						
						|  | sys.exit(1) | 
					
						
						|  | logger.info(f"Using GPU: {torch.cuda.get_device_name(0)} (CUDA 12)") | 
					
						
						|  | logger.info(f"Using precision: float16 for model, float32 for CPU processing") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def clean_memory(): | 
					
						
						|  | try: | 
					
						
						|  | torch.cuda.empty_cache() | 
					
						
						|  | gc.collect() | 
					
						
						|  | torch.cuda.ipc_collect() | 
					
						
						|  | torch.cuda.synchronize() | 
					
						
						|  | vram_mb = torch.cuda.memory_allocated() / 1024**2 | 
					
						
						|  | logger.info(f"Memory cleaned: VRAM allocated = {vram_mb:.2f} MB") | 
					
						
						|  | logger.debug(f"VRAM summary: {torch.cuda.memory_summary()}") | 
					
						
						|  | return vram_mb | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to clean memory: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_vram(): | 
					
						
						|  | try: | 
					
						
						|  | result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'], capture_output=True, text=True) | 
					
						
						|  | lines = result.stdout.splitlines() | 
					
						
						|  | if len(lines) > 1: | 
					
						
						|  | used_mb, total_mb = map(int, re.findall(r'\d+', lines[1])) | 
					
						
						|  | free_mb = total_mb - used_mb | 
					
						
						|  | logger.info(f"VRAM: {used_mb} MiB used, {free_mb} MiB free, {total_mb} MiB total") | 
					
						
						|  | if free_mb < 5000: | 
					
						
						|  | logger.warning(f"Low free VRAM ({free_mb} MiB). Close other applications or processes.") | 
					
						
						|  | result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'], capture_output=True, text=True) | 
					
						
						|  | logger.info(f"GPU processes:\n{result.stdout}") | 
					
						
						|  | return free_mb | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to check VRAM: {e}") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | free_vram = check_vram() | 
					
						
						|  | if free_vram is not None and free_vram < 5000: | 
					
						
						|  | logger.warning("Consider terminating high-VRAM processes before continuing.") | 
					
						
						|  | clean_memory() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | logger.info("Loading MusicGen large model into VRAM...") | 
					
						
						|  | local_model_path = "./models/musicgen-large" | 
					
						
						|  | if not os.path.exists(local_model_path): | 
					
						
						|  | logger.error(f"Local model path {local_model_path} does not exist.") | 
					
						
						|  | logger.error("Please download the MusicGen large model weights and place them in the correct directory.") | 
					
						
						|  | sys.exit(1) | 
					
						
						|  | with autocast(dtype=torch.float16): | 
					
						
						|  | musicgen_model = MusicGen.get_pretrained(local_model_path, device=device) | 
					
						
						|  | musicgen_model.set_generation_params( | 
					
						
						|  | duration=30, | 
					
						
						|  | two_step_cfg=False | 
					
						
						|  | ) | 
					
						
						|  | logger.info("MusicGen large model loaded successfully.") | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to load MusicGen model: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | sys.exit(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_disk_space(path="."): | 
					
						
						|  | try: | 
					
						
						|  | stat = os.statvfs(path) | 
					
						
						|  | free_space = stat.f_bavail * stat.f_frsize / (1024**3) | 
					
						
						|  | if free_space < 1.0: | 
					
						
						|  | logger.warning(f"Low disk space ({free_space:.2f} GB). Ensure at least 1 GB free.") | 
					
						
						|  | return free_space >= 1.0 | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to check disk space: {e}") | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def ensure_stereo(audio_segment, sample_rate=48000, sample_width=2): | 
					
						
						|  | """Ensure the audio segment is stereo (2 channels).""" | 
					
						
						|  | try: | 
					
						
						|  | if audio_segment.channels != 2: | 
					
						
						|  | logger.debug(f"Converting to stereo: {audio_segment.channels} channels detected") | 
					
						
						|  | audio_segment = audio_segment.set_channels(2) | 
					
						
						|  | if audio_segment.frame_rate != sample_rate: | 
					
						
						|  | logger.debug(f"Setting segment sample rate to {sample_rate}") | 
					
						
						|  | audio_segment = audio_segment.set_frame_rate(sample_rate) | 
					
						
						|  | return audio_segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to ensure stereo: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return audio_segment | 
					
						
						|  |  | 
					
						
						|  | def balance_stereo(audio_segment, noise_threshold=-40, sample_rate=48000): | 
					
						
						|  | logger.debug(f"Balancing stereo for segment with sample rate {sample_rate}") | 
					
						
						|  | try: | 
					
						
						|  | audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) | 
					
						
						|  | samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) | 
					
						
						|  | if audio_segment.channels == 2: | 
					
						
						|  | stereo_samples = samples.reshape(-1, 2) | 
					
						
						|  | db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10) | 
					
						
						|  | mask = db_samples > noise_threshold | 
					
						
						|  | stereo_samples = stereo_samples * mask | 
					
						
						|  | left_nonzero = stereo_samples[:, 0][stereo_samples[:, 0] != 0] | 
					
						
						|  | right_nonzero = stereo_samples[:, 1][stereo_samples[:, 1] != 0] | 
					
						
						|  | left_rms = np.sqrt(np.mean(left_nonzero**2)) if len(left_nonzero) > 0 else 0 | 
					
						
						|  | right_rms = np.sqrt(np.mean(right_nonzero**2)) if len(right_nonzero) > 0 else 0 | 
					
						
						|  | if left_rms > 0 and right_rms > 0: | 
					
						
						|  | avg_rms = (left_rms + right_rms) / 2 | 
					
						
						|  | stereo_samples[:, 0] = stereo_samples[:, 0] * (avg_rms / left_rms) | 
					
						
						|  | stereo_samples[:, 1] = stereo_samples[:, 1] * (avg_rms / right_rms) | 
					
						
						|  | balanced_samples = stereo_samples.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16) | 
					
						
						|  | if len(balanced_samples) % 2 != 0: | 
					
						
						|  | balanced_samples = balanced_samples[:-1] | 
					
						
						|  | balanced_segment = AudioSegment( | 
					
						
						|  | balanced_samples.tobytes(), | 
					
						
						|  | frame_rate=sample_rate, | 
					
						
						|  | sample_width=audio_segment.sample_width, | 
					
						
						|  | channels=2 | 
					
						
						|  | ) | 
					
						
						|  | logger.debug("Stereo balancing completed") | 
					
						
						|  | return balanced_segment | 
					
						
						|  | logger.error("Failed to ensure stereo channels") | 
					
						
						|  | return audio_segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to balance stereo: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return audio_segment | 
					
						
						|  |  | 
					
						
						|  | def calculate_rms(segment): | 
					
						
						|  | try: | 
					
						
						|  | samples = np.array(segment.get_array_of_samples(), dtype=np.float32) | 
					
						
						|  | rms = np.sqrt(np.mean(samples**2)) | 
					
						
						|  | logger.debug(f"Calculated RMS: {rms}") | 
					
						
						|  | return rms | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to calculate RMS: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return 0 | 
					
						
						|  |  | 
					
						
						|  | def rms_normalize(segment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000): | 
					
						
						|  | logger.debug(f"Normalizing RMS for segment with target {target_rms_db} dBFS") | 
					
						
						|  | try: | 
					
						
						|  | segment = ensure_stereo(segment, sample_rate, segment.sample_width) | 
					
						
						|  | target_rms = 10 ** (target_rms_db / 20) * (2**23 if segment.sample_width == 3 else 32767) | 
					
						
						|  | current_rms = calculate_rms(segment) | 
					
						
						|  | if current_rms > 0: | 
					
						
						|  | gain_factor = target_rms / current_rms | 
					
						
						|  | segment = segment.apply_gain(20 * np.log10(gain_factor)) | 
					
						
						|  | segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate) | 
					
						
						|  | logger.debug("RMS normalization completed") | 
					
						
						|  | return segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to normalize RMS: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return segment | 
					
						
						|  |  | 
					
						
						|  | def hard_limit(audio_segment, limit_db=-3.0, sample_rate=48000): | 
					
						
						|  | logger.debug(f"Applying hard limit at {limit_db} dBFS") | 
					
						
						|  | try: | 
					
						
						|  | audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) | 
					
						
						|  | limit = 10 ** (limit_db / 20.0) * (2**23 if audio_segment.sample_width == 3 else 32767) | 
					
						
						|  | samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) | 
					
						
						|  | samples = np.clip(samples, -limit, limit).astype(np.int32 if audio_segment.sample_width == 3 else np.int16) | 
					
						
						|  | if len(samples) % 2 != 0: | 
					
						
						|  | samples = samples[:-1] | 
					
						
						|  | limited_segment = AudioSegment( | 
					
						
						|  | samples.tobytes(), | 
					
						
						|  | frame_rate=sample_rate, | 
					
						
						|  | sample_width=audio_segment.sample_width, | 
					
						
						|  | channels=2 | 
					
						
						|  | ) | 
					
						
						|  | logger.debug("Hard limit applied") | 
					
						
						|  | return limited_segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to apply hard limit: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return audio_segment | 
					
						
						|  |  | 
					
						
						|  | def apply_noise_gate(audio_segment, threshold_db=-80, sample_rate=48000): | 
					
						
						|  | logger.debug(f"Applying noise gate with threshold {threshold_db} dBFS") | 
					
						
						|  | try: | 
					
						
						|  | audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width) | 
					
						
						|  | samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32) | 
					
						
						|  | if audio_segment.channels == 2: | 
					
						
						|  | stereo_samples = samples.reshape(-1, 2) | 
					
						
						|  | db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10) | 
					
						
						|  | mask = db_samples > threshold_db | 
					
						
						|  | stereo_samples = stereo_samples * mask | 
					
						
						|  |  | 
					
						
						|  | db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10) | 
					
						
						|  | mask = db_samples > threshold_db | 
					
						
						|  | stereo_samples = stereo_samples * mask | 
					
						
						|  | gated_samples = stereo_samples.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16) | 
					
						
						|  | if len(gated_samples) % 2 != 0: | 
					
						
						|  | gated_samples = gated_samples[:-1] | 
					
						
						|  | gated_segment = AudioSegment( | 
					
						
						|  | gated_samples.tobytes(), | 
					
						
						|  | frame_rate=sample_rate, | 
					
						
						|  | sample_width=audio_segment.sample_width, | 
					
						
						|  | channels=2 | 
					
						
						|  | ) | 
					
						
						|  | logger.debug("Noise gate applied") | 
					
						
						|  | return gated_segment | 
					
						
						|  | logger.error("Failed to ensure stereo channels for noise gate") | 
					
						
						|  | return audio_segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to apply noise gate: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return audio_segment | 
					
						
						|  |  | 
					
						
						|  | def apply_eq(segment, sample_rate=48000): | 
					
						
						|  | logger.debug(f"Applying EQ with sample rate {sample_rate}") | 
					
						
						|  | try: | 
					
						
						|  | segment = ensure_stereo(segment, sample_rate, segment.sample_width) | 
					
						
						|  |  | 
					
						
						|  | segment = segment.high_pass_filter(20) | 
					
						
						|  |  | 
					
						
						|  | segment = segment.low_pass_filter(8000) | 
					
						
						|  |  | 
					
						
						|  | segment = segment - 3 | 
					
						
						|  |  | 
					
						
						|  | segment = segment - 3 | 
					
						
						|  |  | 
					
						
						|  | segment = segment - 10 | 
					
						
						|  | logger.debug("EQ applied: 8 kHz low-pass, 3 dB reduction at 1-8 kHz, 3 dB notch at 12 kHz, 10 dB high-shelf above 5 kHz") | 
					
						
						|  | return segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to apply EQ: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return segment | 
					
						
						|  |  | 
					
						
						|  | def apply_fade(segment, fade_in_duration=500, fade_out_duration=500): | 
					
						
						|  | logger.debug(f"Applying fade: in={fade_in_duration}ms, out={fade_out_duration}ms") | 
					
						
						|  | try: | 
					
						
						|  | segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width) | 
					
						
						|  | segment = segment.fade_in(fade_in_duration) | 
					
						
						|  | segment = segment.fade_out(fade_out_duration) | 
					
						
						|  | logger.debug("Fade applied") | 
					
						
						|  | return segment | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to apply fade: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return segment | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (90, 130) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", standard rock drums with occasional funk grooves and dynamic fills" if drum_beat == "none" else f", {drum_beat} drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", funky bass lines with slap technique and melodic variation" if bass_style == "none" else f", {bass_style} bass" | 
					
						
						|  | guitar = f", energetic guitar riffs with punk rock energy and tonal shifts" if guitar_style == "none" else f", {guitar_style} guitar" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | base_prompt = ( | 
					
						
						|  | f"Instrumental alternative rock by Red Hot Chili Peppers{guitar}{bass}{drum}{synth}, blending funk rock and rap rock elements, " | 
					
						
						|  | f"capturing the raw energy of early 90s rock with dynamic variation to avoid monotony at {bpm} BPM" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if chunk_num == 1: | 
					
						
						|  | prompt = base_prompt + ", featuring a dynamic intro and expressive verse with a mix of upbeat and introspective tones." | 
					
						
						|  | else: | 
					
						
						|  | prompt = base_prompt + ", featuring a powerful chorus and energetic outro with heightened intensity and drive." | 
					
						
						|  |  | 
					
						
						|  | logger.debug(f"Generated RHCP prompt for chunk {chunk_num}: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate RHCP prompt for chunk {chunk_num}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (100, 130) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", standard rock drums, punk energy" if drum_beat == "none" else f", {drum_beat} drums, punk energy" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | chosen_bass = random.choice(['deep bass', 'melodic bass']) if bass_style == "none" else bass_style | 
					
						
						|  | bass = f", {chosen_bass}" | 
					
						
						|  | chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style | 
					
						
						|  | guitar = f", {chosen_guitar}" | 
					
						
						|  | chosen_rhythm = random.choice(['steady steps', 'dynamic shifts']) if rhythmic_steps == "none" else rhythmic_steps | 
					
						
						|  | rhythm = f", {chosen_rhythm}" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental grunge by Nirvana{guitar}{bass}{drum}{synth}, raw lo-fi production, emotional rawness{rhythm} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Nirvana prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Nirvana prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (100, 140) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", standard rock drums, driving rhythm" if drum_beat == "none" else f", {drum_beat} drums, driving rhythm" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", melodic bass, emotional tone" if bass_style == "none" else f", {bass_style}, emotional tone" | 
					
						
						|  | chosen_guitar = random.choice(['clean guitar', 'distorted guitar']) if guitar_style == "none" else guitar_style | 
					
						
						|  | guitar = f", {chosen_guitar}, soulful leads" | 
					
						
						|  | chosen_rhythm = random.choice(['steady steps', 'syncopated steps']) if rhythmic_steps == "none" else rhythmic_steps | 
					
						
						|  | rhythm = f", {chosen_rhythm}" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental grunge by Pearl Jam{guitar}{bass}{drum}{synth}, classic rock influences, narrative depth{rhythm} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Pearl Jam prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Pearl Jam prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (90, 140) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", standard rock drums, heavy rhythm" if drum_beat == "none" else f", {drum_beat} drums, heavy rhythm" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", deep bass, sludgy tone" if bass_style == "none" else f", {bass_style}, sludgy tone" | 
					
						
						|  | guitar = f", distorted guitar, downtuned riffs, psychedelic vibe" if guitar_style == "none" else f", {guitar_style}, downtuned riffs, psychedelic vibe" | 
					
						
						|  | rhythm = f", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental grunge with heavy metal influences by Soundgarden{guitar}{bass}{drum}{synth}, vocal-driven melody, experimental time signatures{rhythm} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Soundgarden prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Soundgarden prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (110, 150) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", standard rock drums, powerful drive" if drum_beat == "none" else f", {drum_beat} drums, powerful drive" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", melodic bass, supportive tone" if bass_style == "none" else f", {bass_style}, supportive tone" | 
					
						
						|  | chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style | 
					
						
						|  | guitar = f", {chosen_guitar}, anthemic quality" | 
					
						
						|  | chosen_rhythm = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps | 
					
						
						|  | rhythm = f", {chosen_rhythm}" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental alternative rock with post-grunge influences by Foo Fighters{guitar}, stadium-ready hooks{bass}{drum}{synth}, Grohlβs raw energy{rhythm} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Foo Fighters prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Foo Fighters prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | bpm_range = (120, 180) | 
					
						
						|  | bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm | 
					
						
						|  | drum = f", double bass drums" if drum_beat == "none" else f", {drum_beat} drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", aggressive bass" if bass_style == "none" else f", {bass_style}" | 
					
						
						|  | guitar = f", distorted guitar, blazing fast riffs" if guitar_style == "none" else f", {guitar_style}, blazing fast riffs" | 
					
						
						|  | rhythm = f", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental thrash metal by Metallica{guitar}{bass}{drum}{synth}, raw intensity{rhythm} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Metallica prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Metallica prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else "" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else ", lush synths" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else "" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", dreamy guitar" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental alternative rock by Smashing Pumpkins{guitar}{synth}{drum}{bass} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Smashing Pumpkins prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Smashing Pumpkins prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else "" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else ", atmospheric synths" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", hypnotic bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style != "none" else "" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental experimental rock by Radiohead{synth}{bass}{drum}{guitar} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Radiohead prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Radiohead prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else "" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", melodic bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", distorted guitar" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental alternative rock by Pixies{guitar}{bass}{drum}{synth} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Alternative Rock prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Alternative Rock prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else ", precise drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", driving bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental post-punk by Joy Division{guitar}{bass}{drum}{synth} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Post-Punk prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Post-Punk prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else "" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", groovy bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style == "none" else ", jangly guitar" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental indie rock by Arctic Monkeys{guitar}{bass}{drum}{synth} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Indie Rock prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Indie Rock prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else ", heavy drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else "" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", slap bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style == "none" else ", funky guitar" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental funk rock by Rage Against the Machine{guitar}{bass}{drum}{synth} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Funk Rock prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Funk Rock prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat != "none" else ", four-on-the-floor drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else ", pulsing synths" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", driving bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style == "none" else "" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental Detroit techno by Juan Atkins{synth}{bass}{drum}{guitar} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Detroit Techno prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Detroit Techno prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  | def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style): | 
					
						
						|  | try: | 
					
						
						|  | drum = f", {drum_beat} drums" if drum_beat == "none" else ", steady kick drums" | 
					
						
						|  | synth = f", {synthesizer}" if synthesizer != "none" else ", warm synths" | 
					
						
						|  | bass = f", {bass_style} bass" if bass_style == "none" else ", deep bass" | 
					
						
						|  | guitar = f", {guitar_style} guitar" if guitar_style == "none" else "" | 
					
						
						|  | prompt = ( | 
					
						
						|  | f"Instrumental deep house by Larry Heard{synth}{bass}{drum}{guitar} at {bpm} BPM." | 
					
						
						|  | ) | 
					
						
						|  | logger.debug(f"Generated Deep House prompt: {prompt}") | 
					
						
						|  | return prompt | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to generate Deep House prompt: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return "" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | PRESETS = { | 
					
						
						|  | "default": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}, | 
					
						
						|  | "rock": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}, | 
					
						
						|  | "techno": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}, | 
					
						
						|  | "grunge": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}, | 
					
						
						|  | "indie": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}, | 
					
						
						|  | "funk_rock": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15} | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_latest_log(): | 
					
						
						|  | try: | 
					
						
						|  | log_files = sorted(Path(log_dir).glob("musicgen_log_*.log"), key=os.path.getmtime, reverse=True) | 
					
						
						|  | if not log_files: | 
					
						
						|  | logger.warning("No log files found") | 
					
						
						|  | return "No log files found." | 
					
						
						|  | with open(log_files[0], "r") as f: | 
					
						
						|  | content = f.read() | 
					
						
						|  | logger.info(f"Retrieved latest log file: {log_files[0]}") | 
					
						
						|  | return content | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to read log file: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return f"Error reading log file: {e}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_bitrate_128(): | 
					
						
						|  | logger.info("Bitrate set to 128 kbps") | 
					
						
						|  | return "128k" | 
					
						
						|  |  | 
					
						
						|  | def set_bitrate_192(): | 
					
						
						|  | logger.info("Bitrate set to 192 kbps") | 
					
						
						|  | return "192k" | 
					
						
						|  |  | 
					
						
						|  | def set_bitrate_320(): | 
					
						
						|  | logger.info("Bitrate set to 320 kbps") | 
					
						
						|  | return "320k" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_sample_rate_22050(): | 
					
						
						|  | logger.info("Output sampling rate set to 22.05 kHz") | 
					
						
						|  | return "22050" | 
					
						
						|  |  | 
					
						
						|  | def set_sample_rate_44100(): | 
					
						
						|  | logger.info("Output sampling rate set to 44.1 kHz") | 
					
						
						|  | return "44100" | 
					
						
						|  |  | 
					
						
						|  | def set_sample_rate_48000(): | 
					
						
						|  | logger.info("Output sampling rate set to 48 kHz") | 
					
						
						|  | return "48000" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def set_bit_depth_16(): | 
					
						
						|  | logger.info("Bit depth set to 16-bit") | 
					
						
						|  | return "16" | 
					
						
						|  |  | 
					
						
						|  | def set_bit_depth_24(): | 
					
						
						|  | logger.info("Bit depth set to 24-bit") | 
					
						
						|  | return "24" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate_music_wrapper(*args): | 
					
						
						|  | try: | 
					
						
						|  | result = generate_music(*args) | 
					
						
						|  | return result | 
					
						
						|  | finally: | 
					
						
						|  | clean_memory() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, bpm: int, drum_beat: str, synthesizer: str, rhythmic_steps: str, bass_style: str, guitar_style: str, target_volume: float, preset: str, max_steps: str, vram_status: str, bitrate: str, output_sample_rate: str, bit_depth: str): | 
					
						
						|  | global musicgen_model | 
					
						
						|  | if not instrumental_prompt.strip(): | 
					
						
						|  | logger.warning("Empty instrumental prompt provided") | 
					
						
						|  | return None, "β οΈ Please enter a valid instrumental prompt!", vram_status | 
					
						
						|  | try: | 
					
						
						|  | logger.info("Starting music generation...") | 
					
						
						|  | start_time = time.time() | 
					
						
						|  | clean_memory() | 
					
						
						|  | try: | 
					
						
						|  | max_steps_int = int(max_steps) | 
					
						
						|  | except ValueError: | 
					
						
						|  | logger.error(f"Invalid max_steps value: {max_steps}") | 
					
						
						|  | return None, "β Invalid max_steps value; must be a number (1000, 1200, 1300, or 1500)", vram_status | 
					
						
						|  | try: | 
					
						
						|  | output_sample_rate_int = int(output_sample_rate) | 
					
						
						|  | except ValueError: | 
					
						
						|  | logger.error(f"Invalid output_sample_rate value: {output_sample_rate}") | 
					
						
						|  | return None, "β Invalid output sampling rate; must be a number (22050, 32000, 44100, or 48000)", vram_status | 
					
						
						|  | try: | 
					
						
						|  | bit_depth_int = int(bit_depth) | 
					
						
						|  | sample_width = 3 if bit_depth_int == 24 else 2 | 
					
						
						|  | except ValueError: | 
					
						
						|  | logger.error(f"Invalid bit_depth value: {bit_depth}") | 
					
						
						|  | return None, "β Invalid bit depth; must be 16 or 24", vram_status | 
					
						
						|  | max_duration = min(max_steps_int / 50, 30) | 
					
						
						|  | total_duration = min(max(total_duration, 30), 120) | 
					
						
						|  | processing_sample_rate = 48000 | 
					
						
						|  | channels = 2 | 
					
						
						|  | audio_segments = [] | 
					
						
						|  | overlap_duration = 0.2 | 
					
						
						|  | remaining_duration = total_duration | 
					
						
						|  |  | 
					
						
						|  | if preset != "default": | 
					
						
						|  | preset_params = PRESETS.get(preset, PRESETS["default"]) | 
					
						
						|  | cfg_scale = preset_params["cfg_scale"] | 
					
						
						|  | top_k = preset_params["top_k"] | 
					
						
						|  | top_p = preset_params["top_p"] | 
					
						
						|  | temperature = preset_params["temperature"] | 
					
						
						|  | logger.info(f"Applied preset {preset}: cfg_scale={cfg_scale}, top_k={top_k}, top_p={top_p}, temperature={temperature}") | 
					
						
						|  |  | 
					
						
						|  | if not check_disk_space(): | 
					
						
						|  | logger.error("Insufficient disk space") | 
					
						
						|  | return None, "β οΈ Insufficient disk space. Free up at least 1 GB.", vram_status | 
					
						
						|  |  | 
					
						
						|  | seed = random.randint(0, 10000) | 
					
						
						|  | logger.info(f"Generating audio for {total_duration}s with seed={seed}, max_steps={max_steps_int}, output_sample_rate={output_sample_rate_int} Hz, bit_depth={bit_depth_int}-bit") | 
					
						
						|  | vram_status = f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" | 
					
						
						|  |  | 
					
						
						|  | chunk_num = 0 | 
					
						
						|  | while remaining_duration > 0: | 
					
						
						|  | current_duration = min(max_duration, remaining_duration) | 
					
						
						|  | generation_duration = current_duration | 
					
						
						|  | chunk_num += 1 | 
					
						
						|  | logger.info(f"Generating chunk {chunk_num} ({current_duration}s, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB)") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if "Red Hot Chili Peppers" in instrumental_prompt: | 
					
						
						|  | chunk_prompt = set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | chunk_prompt = instrumental_prompt | 
					
						
						|  |  | 
					
						
						|  | musicgen_model.set_generation_params( | 
					
						
						|  | duration=generation_duration, | 
					
						
						|  | use_sampling=True, | 
					
						
						|  | top_k=top_k, | 
					
						
						|  | top_p=top_p, | 
					
						
						|  | temperature=temperature, | 
					
						
						|  | cfg_coef=cfg_scale | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | with autocast(dtype=torch.float16): | 
					
						
						|  | torch.manual_seed(seed) | 
					
						
						|  | np.random.seed(seed) | 
					
						
						|  | torch.cuda.manual_seed_all(seed) | 
					
						
						|  | clean_memory() | 
					
						
						|  | if not audio_segments: | 
					
						
						|  | logger.debug("Generating first chunk") | 
					
						
						|  | audio_segment = musicgen_model.generate([chunk_prompt], progress=True)[0].cpu() | 
					
						
						|  | else: | 
					
						
						|  | logger.debug("Generating continuation chunk") | 
					
						
						|  | prev_segment = audio_segments[-1] | 
					
						
						|  | prev_segment = apply_noise_gate(prev_segment, threshold_db=-80, sample_rate=processing_sample_rate) | 
					
						
						|  | prev_segment = balance_stereo(prev_segment, noise_threshold=-40, sample_rate=processing_sample_rate) | 
					
						
						|  | temp_wav_path = f"temp_prev_{int(time.time()*1000)}.wav" | 
					
						
						|  | try: | 
					
						
						|  | logger.debug(f"Exporting previous segment to {temp_wav_path}") | 
					
						
						|  | prev_segment.export(temp_wav_path, format="wav") | 
					
						
						|  | with open(temp_wav_path, "rb") as f: | 
					
						
						|  | mmapped_file = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) | 
					
						
						|  | prev_audio, prev_sr = torchaudio.load(temp_wav_path) | 
					
						
						|  | mmapped_file.close() | 
					
						
						|  | if prev_sr != processing_sample_rate: | 
					
						
						|  | logger.debug(f"Resampling from {prev_sr} to {processing_sample_rate}") | 
					
						
						|  | prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, processing_sample_rate, lowpass_filter_width=64) | 
					
						
						|  | if prev_audio.shape[0] != 2: | 
					
						
						|  | logger.debug(f"Converting to stereo: {prev_audio.shape[0]} channels detected") | 
					
						
						|  | prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]] | 
					
						
						|  | prev_audio = prev_audio.to(device) | 
					
						
						|  | audio_segment = musicgen_model.generate_continuation( | 
					
						
						|  | prompt=prev_audio[:, -int(processing_sample_rate * overlap_duration):], | 
					
						
						|  | prompt_sample_rate=processing_sample_rate, | 
					
						
						|  | descriptions=[chunk_prompt], | 
					
						
						|  | progress=True | 
					
						
						|  | )[0].cpu() | 
					
						
						|  | del prev_audio | 
					
						
						|  | finally: | 
					
						
						|  | try: | 
					
						
						|  | os.remove(temp_wav_path) | 
					
						
						|  | logger.debug(f"Deleted temporary file {temp_wav_path}") | 
					
						
						|  | except OSError: | 
					
						
						|  | logger.warning(f"Failed to delete temporary file {temp_wav_path}") | 
					
						
						|  | clean_memory() | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error in chunk {chunk_num} generation: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return None, f"β Failed to generate chunk {chunk_num}: {e}", vram_status | 
					
						
						|  |  | 
					
						
						|  | logger.debug(f"Generated audio segment shape: {audio_segment.shape}, dtype: {audio_segment.dtype}") | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | if audio_segment.shape[0] != 2: | 
					
						
						|  | logger.debug(f"Converting to stereo: {audio_segment.shape[0]} channels detected") | 
					
						
						|  | audio_segment = audio_segment.repeat(2, 1)[:, :audio_segment.shape[1]] | 
					
						
						|  |  | 
					
						
						|  | audio_segment = audio_segment.to(dtype=torch.float32) | 
					
						
						|  | audio_segment = torchaudio.functional.resample(audio_segment, 32000, processing_sample_rate, lowpass_filter_width=64) | 
					
						
						|  | audio_np = audio_segment.numpy() | 
					
						
						|  | if audio_np.ndim == 1: | 
					
						
						|  | logger.debug("Converting mono to stereo on CPU") | 
					
						
						|  | audio_np = np.stack([audio_np, audio_np], axis=0) | 
					
						
						|  | if audio_np.shape[0] != 2: | 
					
						
						|  | logger.error(f"Expected stereo audio with shape (2, samples), got shape {audio_np.shape}") | 
					
						
						|  | return None, f"β Invalid audio shape for chunk {chunk_num}: {audio_np.shape}", vram_status | 
					
						
						|  | audio_segment = torch.from_numpy(audio_np).to(dtype=torch.float16) | 
					
						
						|  | logger.debug(f"Converted audio segment to float16, shape: {audio_segment.shape}") | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return None, f"β Failed to process audio for chunk {chunk_num}: {e}", vram_status | 
					
						
						|  |  | 
					
						
						|  | temp_wav_path = f"temp_audio_{int(time.time()*1000)}.wav" | 
					
						
						|  | logger.debug(f"Saving audio segment to {temp_wav_path}, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") | 
					
						
						|  | try: | 
					
						
						|  | audio_segment_save = audio_segment.to(dtype=torch.float32) | 
					
						
						|  | torchaudio.save(temp_wav_path, audio_segment_save, processing_sample_rate, bits_per_sample=bit_depth_int) | 
					
						
						|  | del audio_segment_save | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to save audio segment for chunk {chunk_num}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | logger.warning(f"Skipping chunk {chunk_num} due to save error") | 
					
						
						|  | del audio_segment | 
					
						
						|  | clean_memory() | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | clean_memory() | 
					
						
						|  | try: | 
					
						
						|  | with open(temp_wav_path, "rb") as f: | 
					
						
						|  | mmapped_file = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) | 
					
						
						|  | segment = AudioSegment.from_wav(temp_wav_path) | 
					
						
						|  | mmapped_file.close() | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to load WAV file for chunk {chunk_num}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | logger.warning(f"Skipping chunk {chunk_num} due to WAV load error") | 
					
						
						|  | del audio_segment | 
					
						
						|  | clean_memory() | 
					
						
						|  | continue | 
					
						
						|  | finally: | 
					
						
						|  | try: | 
					
						
						|  | os.remove(temp_wav_path) | 
					
						
						|  | logger.debug(f"Deleted temporary file {temp_wav_path}") | 
					
						
						|  | except OSError: | 
					
						
						|  | logger.warning(f"Failed to delete temporary file {temp_wav_path}") | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | segment = ensure_stereo(segment, processing_sample_rate, sample_width) | 
					
						
						|  | segment = segment - 15 | 
					
						
						|  | if segment.frame_rate != processing_sample_rate: | 
					
						
						|  | logger.debug(f"Setting segment sample rate to {processing_sample_rate}") | 
					
						
						|  | segment = segment.set_frame_rate(processing_sample_rate) | 
					
						
						|  |  | 
					
						
						|  | segment = apply_noise_gate(segment, threshold_db=-80, sample_rate=processing_sample_rate) | 
					
						
						|  | segment = balance_stereo(segment, noise_threshold=-40, sample_rate=processing_sample_rate) | 
					
						
						|  | segment = rms_normalize(segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate) | 
					
						
						|  | segment = apply_eq(segment, sample_rate=processing_sample_rate) | 
					
						
						|  | audio_segments.append(segment) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | logger.warning(f"Skipping chunk {chunk_num} due to processing error") | 
					
						
						|  | del audio_segment | 
					
						
						|  | clean_memory() | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | del audio_segment | 
					
						
						|  | del audio_np | 
					
						
						|  | clean_memory() | 
					
						
						|  | vram_status = f"VRAM after chunk {chunk_num}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" | 
					
						
						|  | time.sleep(0.1) | 
					
						
						|  | remaining_duration -= current_duration | 
					
						
						|  |  | 
					
						
						|  | if not audio_segments: | 
					
						
						|  | logger.error("No audio segments generated") | 
					
						
						|  | return None, "β No audio segments generated due to errors", vram_status | 
					
						
						|  |  | 
					
						
						|  | logger.info("Combining audio chunks...") | 
					
						
						|  | try: | 
					
						
						|  | final_segment = audio_segments[0][:min(max_duration, total_duration) * 1000] | 
					
						
						|  | final_segment = ensure_stereo(final_segment, processing_sample_rate, sample_width) | 
					
						
						|  | overlap_ms = int(overlap_duration * 1000) | 
					
						
						|  |  | 
					
						
						|  | for i in range(1, len(audio_segments)): | 
					
						
						|  | current_segment = audio_segments[i] | 
					
						
						|  | current_segment = current_segment[:min(max_duration, total_duration - (i * max_duration)) * 1000] | 
					
						
						|  | current_segment = ensure_stereo(current_segment, processing_sample_rate, sample_width) | 
					
						
						|  |  | 
					
						
						|  | if overlap_ms > 0 and len(current_segment) > overlap_ms: | 
					
						
						|  | logger.debug(f"Applying crossfade between chunks {i} and {i+1}") | 
					
						
						|  | prev_overlap = final_segment[-overlap_ms:] | 
					
						
						|  | curr_overlap = current_segment[:overlap_ms] | 
					
						
						|  | prev_wav_path = f"temp_prev_overlap_{int(time.time()*1000)}.wav" | 
					
						
						|  | curr_wav_path = f"temp_curr_overlap_{int(time.time()*1000)}.wav" | 
					
						
						|  | try: | 
					
						
						|  | prev_overlap.export(prev_wav_path, format="wav") | 
					
						
						|  | curr_overlap.export(curr_wav_path, format="wav") | 
					
						
						|  | clean_memory() | 
					
						
						|  | prev_audio, _ = torchaudio.load(prev_wav_path) | 
					
						
						|  | curr_audio, _ = torchaudio.load(curr_wav_path) | 
					
						
						|  | num_samples = min(prev_audio.shape[1], curr_audio.shape[1]) | 
					
						
						|  | num_samples = num_samples - (num_samples % 2) | 
					
						
						|  | if num_samples <= 0: | 
					
						
						|  | logger.warning(f"Skipping crossfade for chunk {i+1} due to insufficient samples") | 
					
						
						|  | final_segment += current_segment | 
					
						
						|  | continue | 
					
						
						|  | blended_samples = torch.zeros(2, num_samples, dtype=torch.float32) | 
					
						
						|  | prev_samples = prev_audio[:, :num_samples] | 
					
						
						|  | curr_samples = curr_audio[:, :num_samples] | 
					
						
						|  | hann_window = torch.hann_window(num_samples, periodic=False) | 
					
						
						|  | fade_out = hann_window.flip(0) | 
					
						
						|  | fade_in = hann_window | 
					
						
						|  | blended_samples = (prev_samples * fade_out + curr_samples * fade_in) | 
					
						
						|  | blended_samples = (blended_samples * (2**23 if sample_width == 3 else 32767)).to(torch.int32 if sample_width == 3 else torch.int16) | 
					
						
						|  | temp_crossfade_path = f"temp_crossfade_{int(time.time()*1000)}.wav" | 
					
						
						|  | torchaudio.save(temp_crossfade_path, blended_samples, processing_sample_rate, bits_per_sample=bit_depth_int) | 
					
						
						|  | blended_segment = AudioSegment.from_wav(temp_crossfade_path) | 
					
						
						|  | blended_segment = ensure_stereo(blended_segment, processing_sample_rate, sample_width) | 
					
						
						|  | blended_segment = rms_normalize(blended_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate) | 
					
						
						|  | final_segment = final_segment[:-overlap_ms] + blended_segment + current_segment[overlap_ms:] | 
					
						
						|  | finally: | 
					
						
						|  | for temp_path in [prev_wav_path, curr_wav_path, temp_crossfade_path]: | 
					
						
						|  | try: | 
					
						
						|  | if os.path.exists(temp_path): | 
					
						
						|  | os.remove(temp_path) | 
					
						
						|  | logger.debug(f"Deleted temporary file {temp_path}") | 
					
						
						|  | except OSError: | 
					
						
						|  | logger.warning(f"Failed to delete temporary file {temp_path}") | 
					
						
						|  | else: | 
					
						
						|  | logger.debug(f"Concatenating chunk {i+1} without crossfade") | 
					
						
						|  | final_segment += current_segment | 
					
						
						|  |  | 
					
						
						|  | final_segment = final_segment[:total_duration * 1000] | 
					
						
						|  | logger.info("Post-processing final track...") | 
					
						
						|  | final_segment = apply_noise_gate(final_segment, threshold_db=-80, sample_rate=processing_sample_rate) | 
					
						
						|  | final_segment = balance_stereo(final_segment, noise_threshold=-40, sample_rate=processing_sample_rate) | 
					
						
						|  | final_segment = rms_normalize(final_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate) | 
					
						
						|  | final_segment = apply_eq(final_segment, sample_rate=processing_sample_rate) | 
					
						
						|  | final_segment = apply_fade(final_segment) | 
					
						
						|  | final_segment = final_segment - 10 | 
					
						
						|  | final_segment = final_segment.set_frame_rate(output_sample_rate_int) | 
					
						
						|  |  | 
					
						
						|  | mp3_path = f"output_adjusted_volume_{int(time.time())}.mp3" | 
					
						
						|  | logger.info("β οΈ WARNING: Audio is set to safe levels (~ -23 dBFS RMS, -3 dBFS peak). Start playback at LOW volume (10-20%) and adjust gradually.") | 
					
						
						|  | logger.info("VERIFY: Open the file in Audacity to check for high-pitched tones and quality. RMS should be ~ -23 dBFS, peaks β€ -3 dBFS. Report any issues.") | 
					
						
						|  | try: | 
					
						
						|  | clean_memory() | 
					
						
						|  | logger.debug(f"Exporting final audio to {mp3_path} with bitrate {bitrate}, sample rate {output_sample_rate_int} Hz, bit depth {bit_depth_int}-bit") | 
					
						
						|  | final_segment.export( | 
					
						
						|  | mp3_path, | 
					
						
						|  | format="mp3", | 
					
						
						|  | bitrate=bitrate, | 
					
						
						|  | tags={"title": "GhostAI Instrumental", "artist": "GhostAI"} | 
					
						
						|  | ) | 
					
						
						|  | logger.info(f"Final audio saved to {mp3_path}") | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error exporting MP3 with bitrate {bitrate}: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | fallback_path = f"fallback_output_{int(time.time())}.mp3" | 
					
						
						|  | try: | 
					
						
						|  | final_segment.export(fallback_path, format="mp3", bitrate="128k") | 
					
						
						|  | logger.info(f"Final audio saved to fallback: {fallback_path} with 128 kbps") | 
					
						
						|  | mp3_path = fallback_path | 
					
						
						|  | except Exception as fallback_e: | 
					
						
						|  | logger.error(f"Failed to save fallback MP3: {fallback_e}") | 
					
						
						|  | return None, f"β Failed to export audio: {fallback_e}", vram_status | 
					
						
						|  |  | 
					
						
						|  | vram_status = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" | 
					
						
						|  | logger.info(f"Generation completed in {time.time() - start_time:.2f} seconds") | 
					
						
						|  | return mp3_path, "β
 Done! Generated track with adjusted volume levels. Check for quality in Audacity.", vram_status | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to combine audio chunks: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return None, f"β Failed to combine audio: {e}", vram_status | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Generation failed: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | return None, f"β Generation failed: {e}", vram_status | 
					
						
						|  | finally: | 
					
						
						|  | clean_memory() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def clear_inputs(): | 
					
						
						|  | logger.info("Clearing input fields") | 
					
						
						|  | return "", 5.8, 18, 0.88, 0.15, 30, 120, "none", "none", "none", "none", "none", -23.0, "default", 1300, "128k", "44100", "16" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Stylesheet for the Gradio UI; it is passed to ``gr.Blocks(css=css)`` when the
# interface is built. The class selectors correspond to the ``elem_classes``
# values given to the containers, rows and buttons (e.g. "input-container",
# "genre-buttons", "genre-btn"). The ``.active`` variants are the highlight
# state that the button-selection JavaScript is meant to toggle.
# NOTE(review): no component in this section sets the ".slider-label",
# ".dropdown-label", ".slider" or ".dropdown" classes, and the "action-buttons"
# row has no rule here. Those selectors may be unused. Verify before relying
# on them.
css = """
body {
background: #121212;
color: #E6E6E6;
font-family: 'Arial', sans-serif;
}
.header-container {
text-align: center;
padding: 15px 20px;
background: #1E1E1E;
border-bottom: 2px solid #00C853;
}
#ghost-logo {
font-size: 48px;
color: #00C853;
}
h1 {
color: #FFD600;
font-size: 28px;
font-weight: bold;
}
h3 {
color: #FFD600;
font-size: 20px;
font-weight: bold;
}
p {
color: #B0BEC5;
font-size: 14px;
}
.input-container, .settings-container, .output-container, .logs-container {
max-width: 1200px;
margin: 20px auto;
padding: 20px;
background: #212121;
border: 1px solid #424242;
border-radius: 8px;
}
.textbox {
background: #2C2C2C;
border: 1px solid #B0BEC5;
color: #E6E6E6;
font-size: 16px;
}
.genre-buttons, .bitrate-buttons, .sample-rate-buttons, .bit-depth-buttons {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 10px;
}
.genre-btn, .bitrate-btn, .sample-rate-btn, .bit-depth-btn, button {
background: #0288D1;
border: 2px solid transparent;
color: #FFFFFF;
padding: 10px 20px;
border-radius: 5px;
font-size: 16px;
transition: all 0.3s ease;
}
button:hover {
background: #03A9F4;
cursor: pointer;
}
button:active, .genre-btn.active, .bitrate-btn.active, .sample-rate-btn.active, .bit-depth-btn.active {
border: 2px solid #00C853 !important;
background: #01579B;
color: #FFFFFF;
}
.gradio-container {
padding: 20px;
}
.group-container {
margin-bottom: 20px;
padding: 15px;
border: 1px solid #424242;
border-radius: 8px;
}
.slider-label, .dropdown-label {
color: #FFD600;
font-size: 16px;
font-weight: bold;
}
.slider, .dropdown {
background: #2C2C2C;
color: #E6E6E6;
}
.output-container label, .logs-container label {
color: #FFD600;
font-size: 16px;
font-weight: bold;
}
"""
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logger.info("Building Gradio interface...") | 
					
						
						|  | with gr.Blocks(css=css) as demo: | 
					
						
						|  | gr.Markdown(""" | 
					
						
						|  | <div class="header-container"> | 
					
						
						|  | <div id="ghost-logo">π»</div> | 
					
						
						|  | <h1>GhostAI Music Generator πΉ</h1> | 
					
						
						|  | <p>Create Instrumental Tracks with Ease</p> | 
					
						
						|  | </div> | 
					
						
						|  | """) | 
					
						
						|  |  | 
					
						
						|  | with gr.Column(elem_classes="input-container"): | 
					
						
						|  | gr.Markdown("### πΈ Prompt Settings") | 
					
						
						|  | instrumental_prompt = gr.Textbox( | 
					
						
						|  | label="Instrumental Prompt βοΈ", | 
					
						
						|  | placeholder="Click a genre button or type your own instrumental prompt", | 
					
						
						|  | lines=4, | 
					
						
						|  | elem_classes="textbox" | 
					
						
						|  | ) | 
					
						
						|  | with gr.Row(elem_classes="genre-buttons"): | 
					
						
						|  | rhcp_btn = gr.Button("Red Hot Chili Peppers πΆοΈ", elem_classes="genre-btn") | 
					
						
						|  | nirvana_btn = gr.Button("Nirvana Grunge πΈ", elem_classes="genre-btn") | 
					
						
						|  | pearl_jam_btn = gr.Button("Pearl Jam Grunge π¦ͺ", elem_classes="genre-btn") | 
					
						
						|  | soundgarden_btn = gr.Button("Soundgarden Grunge π", elem_classes="genre-btn") | 
					
						
						|  | foo_fighters_btn = gr.Button("Foo Fighters π€", elem_classes="genre-btn") | 
					
						
						|  | smashing_pumpkins_btn = gr.Button("Smashing Pumpkins π", elem_classes="genre-btn") | 
					
						
						|  | radiohead_btn = gr.Button("Radiohead π§ ", elem_classes="genre-btn") | 
					
						
						|  | classic_rock_btn = gr.Button("Metallica Heavy Metal πΈ", elem_classes="genre-btn") | 
					
						
						|  | alternative_rock_btn = gr.Button("Alternative Rock π΅", elem_classes="genre-btn") | 
					
						
						|  | post_punk_btn = gr.Button("Post-Punk π€", elem_classes="genre-btn") | 
					
						
						|  | indie_rock_btn = gr.Button("Indie Rock π€", elem_classes="genre-btn") | 
					
						
						|  | funk_rock_btn = gr.Button("Funk Rock πΊ", elem_classes="genre-btn") | 
					
						
						|  | detroit_techno_btn = gr.Button("Detroit Techno ποΈ", elem_classes="genre-btn") | 
					
						
						|  | deep_house_btn = gr.Button("Deep House π ", elem_classes="genre-btn") | 
					
						
						|  |  | 
					
						
						|  | with gr.Column(elem_classes="settings-container"): | 
					
						
						|  | gr.Markdown("### βοΈ API Settings") | 
					
						
						|  | with gr.Group(elem_classes="group-container"): | 
					
						
						|  | cfg_scale = gr.Slider( | 
					
						
						|  | label="CFG Scale π―", | 
					
						
						|  | minimum=1.0, | 
					
						
						|  | maximum=10.0, | 
					
						
						|  | value=5.8, | 
					
						
						|  | step=0.1, | 
					
						
						|  | info="Controls how closely the music follows the prompt." | 
					
						
						|  | ) | 
					
						
						|  | top_k = gr.Slider( | 
					
						
						|  | label="Top-K Sampling π’", | 
					
						
						|  | minimum=10, | 
					
						
						|  | maximum=500, | 
					
						
						|  | value=18, | 
					
						
						|  | step=10, | 
					
						
						|  | info="Limits sampling to the top k most likely tokens." | 
					
						
						|  | ) | 
					
						
						|  | top_p = gr.Slider( | 
					
						
						|  | label="Top-P Sampling π°", | 
					
						
						|  | minimum=0.0, | 
					
						
						|  | maximum=1.0, | 
					
						
						|  | value=0.88, | 
					
						
						|  | step=0.05, | 
					
						
						|  | info="Keeps tokens with cumulative probability above p." | 
					
						
						|  | ) | 
					
						
						|  | temperature = gr.Slider( | 
					
						
						|  | label="Temperature π₯", | 
					
						
						|  | minimum=0.1, | 
					
						
						|  | maximum=2.0, | 
					
						
						|  | value=0.15, | 
					
						
						|  | step=0.1, | 
					
						
						|  | info="Controls randomness; lower values reduce noise." | 
					
						
						|  | ) | 
					
						
						|  | total_duration = gr.Dropdown( | 
					
						
						|  | label="Song Length β³ (seconds)", | 
					
						
						|  | choices=[30, 60, 90, 120], | 
					
						
						|  | value=30, | 
					
						
						|  | info="Select the total duration of the track." | 
					
						
						|  | ) | 
					
						
						|  | bpm = gr.Slider( | 
					
						
						|  | label="Tempo π΅ (BPM)", | 
					
						
						|  | minimum=60, | 
					
						
						|  | maximum=180, | 
					
						
						|  | value=120, | 
					
						
						|  | step=1, | 
					
						
						|  | info="Beats per minute to set the track's tempo." | 
					
						
						|  | ) | 
					
						
						|  | drum_beat = gr.Dropdown( | 
					
						
						|  | label="Drum Beat π₯", | 
					
						
						|  | choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"], | 
					
						
						|  | value="none", | 
					
						
						|  | info="Select a drum beat style to influence the rhythm." | 
					
						
						|  | ) | 
					
						
						|  | synthesizer = gr.Dropdown( | 
					
						
						|  | label="Synthesizer πΉ", | 
					
						
						|  | choices=["none", "analog synth", "digital pad", "arpeggiated synth"], | 
					
						
						|  | value="none", | 
					
						
						|  | info="Select a synthesizer style for electronic accents." | 
					
						
						|  | ) | 
					
						
						|  | rhythmic_steps = gr.Dropdown( | 
					
						
						|  | label="Rhythmic Steps π£", | 
					
						
						|  | choices=["none", "syncopated steps", "steady steps", "complex steps"], | 
					
						
						|  | value="none", | 
					
						
						|  | info="Select a rhythmic step style to enhance the beat." | 
					
						
						|  | ) | 
					
						
						|  | bass_style = gr.Dropdown( | 
					
						
						|  | label="Bass Style πΈ", | 
					
						
						|  | choices=["none", "slap bass", "deep bass", "melodic bass"], | 
					
						
						|  | value="none", | 
					
						
						|  | info="Select a bass style to shape the low end." | 
					
						
						|  | ) | 
					
						
						|  | guitar_style = gr.Dropdown( | 
					
						
						|  | label="Guitar Style πΈ", | 
					
						
						|  | choices=["none", "distorted", "clean", "jangle"], | 
					
						
						|  | value="none", | 
					
						
						|  | info="Select a guitar style to define the riffs." | 
					
						
						|  | ) | 
					
						
						|  | target_volume = gr.Slider( | 
					
						
						|  | label="Target Volume ποΈ (dBFS RMS)", | 
					
						
						|  | minimum=-30.0, | 
					
						
						|  | maximum=-20.0, | 
					
						
						|  | value=-23.0, | 
					
						
						|  | step=1.0, | 
					
						
						|  | info="Adjust output loudness (-23 dBFS is standard, -20 dBFS is louder, -30 dBFS is quieter)." | 
					
						
						|  | ) | 
					
						
						|  | preset = gr.Dropdown( | 
					
						
						|  | label="Preset Configuration ποΈ", | 
					
						
						|  | choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], | 
					
						
						|  | value="default", | 
					
						
						|  | info="Select a preset optimized for specific genres." | 
					
						
						|  | ) | 
					
						
						|  | max_steps = gr.Dropdown( | 
					
						
						|  | label="Max Steps per Chunk π", | 
					
						
						|  | choices=[1000, 1200, 1300, 1500], | 
					
						
						|  | value=1300, | 
					
						
						|  | info="Number of generation steps per chunk (1300=~26s, extended to 30s)." | 
					
						
						|  | ) | 
					
						
						|  | bitrate_state = gr.State(value="128k") | 
					
						
						|  | sample_rate_state = gr.State(value="44100") | 
					
						
						|  | bit_depth_state = gr.State(value="16") | 
					
						
						|  | with gr.Row(elem_classes="bitrate-buttons"): | 
					
						
						|  | bitrate_128_btn = gr.Button("Set Bitrate to 128 kbps", elem_classes="bitrate-btn") | 
					
						
						|  | bitrate_192_btn = gr.Button("Set Bitrate to 192 kbps", elem_classes="bitrate-btn") | 
					
						
						|  | bitrate_320_btn = gr.Button("Set Bitrate to 320 kbps", elem_classes="bitrate-btn") | 
					
						
						|  | with gr.Row(elem_classes="sample-rate-buttons"): | 
					
						
						|  | sample_rate_22050_btn = gr.Button("Set Sampling Rate to 22.05 kHz", elem_classes="sample-rate-btn") | 
					
						
						|  | sample_rate_44100_btn = gr.Button("Set Sampling Rate to 44.1 kHz", elem_classes="sample-rate-btn") | 
					
						
						|  | sample_rate_48000_btn = gr.Button("Set Sampling Rate to 48 kHz", elem_classes="sample-rate-btn") | 
					
						
						|  | with gr.Row(elem_classes="bit-depth-buttons"): | 
					
						
						|  | bit_depth_16_btn = gr.Button("Set Bit Depth to 16-bit", elem_classes="bit-depth-btn") | 
					
						
						|  | bit_depth_24_btn = gr.Button("Set Bit Depth to 24-bit", elem_classes="bit-depth-btn") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(elem_classes="action-buttons"): | 
					
						
						|  | gen_btn = gr.Button("Generate Music π") | 
					
						
						|  | clr_btn = gr.Button("Clear Inputs π§Ή") | 
					
						
						|  |  | 
					
						
						|  | with gr.Column(elem_classes="output-container"): | 
					
						
						|  | gr.Markdown("### π§ Output") | 
					
						
						|  | out_audio = gr.Audio(label="Generated Instrumental Track π΅", type="filepath") | 
					
						
						|  | status = gr.Textbox(label="Status π’", interactive=False) | 
					
						
						|  | vram_status = gr.Textbox(label="VRAM Usage π", interactive=False, value="") | 
					
						
						|  |  | 
					
						
						|  | with gr.Column(elem_classes="logs-container"): | 
					
						
						|  | gr.Markdown("### π Logs") | 
					
						
						|  | log_output = gr.Textbox(label="Last Log File Contents", lines=20, interactive=False) | 
					
						
						|  | log_btn = gr.Button("View Last Log π") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def update_button_styles(selected_button): | 
					
						
						|  | buttons = [ | 
					
						
						|  | "rhcp_btn", "nirvana_btn", "pearl_jam_btn", "soundgarden_btn", "foo_fighters_btn", | 
					
						
						|  | "smashing_pumpkins_btn", "radiohead_btn", "classic_rock_btn", "alternative_rock_btn", | 
					
						
						|  | "post_punk_btn", "indie_rock_btn", "funk_rock_btn", "detroit_techno_btn", "deep_house_btn", | 
					
						
						|  | "bitrate_128_btn", "bitrate_192_btn", "bitrate_320_btn", | 
					
						
						|  | "sample_rate_22050_btn", "sample_rate_44100_btn", "sample_rate_48000_btn", | 
					
						
						|  | "bit_depth_16_btn", "bit_depth_24_btn" | 
					
						
						|  | ] | 
					
						
						|  | script = """ | 
					
						
						|  | <script> | 
					
						
						|  | document.querySelectorAll('.genre-btn, .bitrate-btn, .sample-rate-btn, .bit-depth-btn').forEach(btn => { | 
					
						
						|  | btn.classList.remove('active'); | 
					
						
						|  | }); | 
					
						
						|  | document.querySelector('#""" + selected_button + """').classList.add('active'); | 
					
						
						|  | </script> | 
					
						
						|  | """ | 
					
						
						|  | return script | 
					
						
						|  |  | 
					
						
						|  | rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, gr.State(value=1)], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("rhcp_btn")) | 
					
						
						|  | nirvana_btn.click(set_nirvana_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("nirvana_btn")) | 
					
						
						|  | pearl_jam_btn.click(set_pearl_jam_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("pearl_jam_btn")) | 
					
						
						|  | soundgarden_btn.click(set_soundgarden_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("soundgarden_btn")) | 
					
						
						|  | foo_fighters_btn.click(set_foo_fighters_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("foo_fighters_btn")) | 
					
						
						|  | smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("smashing_pumpkins_btn")) | 
					
						
						|  | radiohead_btn.click(set_radiohead_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("radiohead_btn")) | 
					
						
						|  | classic_rock_btn.click(set_classic_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("classic_rock_btn")) | 
					
						
						|  | alternative_rock_btn.click(set_alternative_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("alternative_rock_btn")) | 
					
						
						|  | post_punk_btn.click(set_post_punk_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("post_punk_btn")) | 
					
						
						|  | indie_rock_btn.click(set_indie_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("indie_rock_btn")) | 
					
						
						|  | funk_rock_btn.click(set_funk_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("funk_rock_btn")) | 
					
						
						|  | detroit_techno_btn.click(set_detroit_techno_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("detroit_techno_btn")) | 
					
						
						|  | deep_house_btn.click(set_deep_house_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt).then(None, None, None, js=update_button_styles("deep_house_btn")) | 
					
						
						|  | bitrate_128_btn.click(set_bitrate_128, inputs=None, outputs=bitrate_state).then(None, None, None, js=update_button_styles("bitrate_128_btn")) | 
					
						
						|  | bitrate_192_btn.click(set_bitrate_192, inputs=None, outputs=bitrate_state).then(None, None, None, js=update_button_styles("bitrate_192_btn")) | 
					
						
						|  | bitrate_320_btn.click(set_bitrate_320, inputs=None, outputs=bitrate_state).then(None, None, None, js=update_button_styles("bitrate_320_btn")) | 
					
						
						|  | sample_rate_22050_btn.click(set_sample_rate_22050, inputs=None, outputs=sample_rate_state).then(None, None, None, js=update_button_styles("sample_rate_22050_btn")) | 
					
						
						|  | sample_rate_44100_btn.click(set_sample_rate_44100, inputs=None, outputs=sample_rate_state).then(None, None, None, js=update_button_styles("sample_rate_44100_btn")) | 
					
						
						|  | sample_rate_48000_btn.click(set_sample_rate_48000, inputs=None, outputs=sample_rate_state).then(None, None, None, js=update_button_styles("sample_rate_48000_btn")) | 
					
						
						|  | bit_depth_16_btn.click(set_bit_depth_16, inputs=None, outputs=bit_depth_state).then(None, None, None, js=update_button_styles("bit_depth_16_btn")) | 
					
						
						|  | bit_depth_24_btn.click(set_bit_depth_24, inputs=None, outputs=bit_depth_state).then(None, None, None, js=update_button_styles("bit_depth_24_btn")) | 
					
						
						|  | gen_btn.click( | 
					
						
						|  | generate_music_wrapper, | 
					
						
						|  | inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, vram_status, bitrate_state, sample_rate_state, bit_depth_state], | 
					
						
						|  | outputs=[out_audio, status, vram_status] | 
					
						
						|  | ) | 
					
						
						|  | clr_btn.click( | 
					
						
						|  | clear_inputs, | 
					
						
						|  | inputs=None, | 
					
						
						|  | outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state] | 
					
						
						|  | ) | 
					
						
						|  | log_btn.click( | 
					
						
						|  | get_latest_log, | 
					
						
						|  | inputs=None, | 
					
						
						|  | outputs=log_output | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logger.info("Launching Gradio UI at http://localhost:9999...") | 
					
						
						|  | try: | 
					
						
						|  | app = demo.launch( | 
					
						
						|  | server_name="0.0.0.0", | 
					
						
						|  | server_port=9999, | 
					
						
						|  | share=True, | 
					
						
						|  | inbrowser=False, | 
					
						
						|  | show_error=True | 
					
						
						|  | ) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Failed to launch Gradio UI: {e}") | 
					
						
						|  | logger.error(traceback.format_exc()) | 
					
						
						|  | sys.exit(1) |