import os
import tempfile
from pathlib import Path

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel

try:
    from transformers import ClapModel, ClapProcessor
    CLAP_AVAILABLE = True
    CLAP_METHOD = "transformers"
except ImportError:
    CLAP_AVAILABLE = False
    CLAP_METHOD = None

# Check MERT availability
try:
    from transformers import AutoModel, Wav2Vec2FeatureExtractor
    MERT_AVAILABLE = True
    MERT_METHOD = "transformers"
except ImportError:
    MERT_AVAILABLE = False
    MERT_METHOD = None

import torch
import torchaudio
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf
import scipy.signal

# Set environment to disable librosa caching
os.environ['LIBROSA_CACHE_DIR'] = '/tmp'
os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'
# Disable librosa caching completely to avoid the __o_fold error
os.environ['LIBROSA_CACHE_LEVEL'] = '0'
os.environ['LIBROSA_CACHE_COMPRESS'] = '0'

# Check pitch-aware model availability
try:
    # Try to use a simpler pitch-aware approach with librosa
    import librosa
    PITCH_AWARE_AVAILABLE = True
    PITCH_METHOD = "librosa_chroma"
except ImportError:
    PITCH_AWARE_AVAILABLE = False
    PITCH_METHOD = None

# Fusion configuration
FUSION_MODE = os.environ.get('FUSION_MODE', 'VECTOR_CONCAT')  # VECTOR_CONCAT or SCORE_FUSION
FUSION_ALPHA = float(os.environ.get('FUSION_ALPHA', '0.6'))  # Alpha for score fusion
ENABLE_PITCH_FUSION = os.environ.get('ENABLE_PITCH_FUSION', 'false').lower() == 'true' and PITCH_AWARE_AVAILABLE
ENABLE_MERT_FUSION = os.environ.get('ENABLE_MERT_FUSION', 'false').lower() == 'true' and MERT_AVAILABLE

# Audio processing limits
MAX_AUDIO_DURATION_SEC = int(os.environ.get('MAX_AUDIO_DURATION_SEC', '600'))  # 10 minutes default

# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")


def sanitize_for_json(embedding):
    """Sanitize embedding for JSON serialization"""
    if isinstance(embedding, np.ndarray):
        # Ensure finite values and convert to list
        embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
        return embedding.tolist()
    elif isinstance(embedding, list):
        # Check each element for finite values
        return [float(x) if np.isfinite(x) else 0.0 for x in embedding]
    else:
        return embedding


# Log CLAP, MERT, and pitch-aware availability after logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")
logger.info(f"MERT availability: {MERT_AVAILABLE}, method: {MERT_METHOD}")
logger.info(f"Pitch-aware availability: {PITCH_AWARE_AVAILABLE}, method: {PITCH_METHOD}")
logger.info(f"Pitch fusion enabled: {ENABLE_PITCH_FUSION}")
logger.info(f"MERT fusion enabled: {ENABLE_MERT_FUSION}")
logger.info(f"Fusion mode: {FUSION_MODE}")
if FUSION_MODE == 'SCORE_FUSION':
    logger.info(f"Fusion alpha: {FUSION_ALPHA}")
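
# Hedged sketch (not called by this service, provided for illustration only): in
# SCORE_FUSION mode the audio endpoint returns per-model embeddings, and a
# downstream consumer is expected to blend the per-model cosine similarities
# itself. The helper below shows one plausible weighting using FUSION_ALPHA; the
# exact scheme a caller uses is an assumption, not something this module defines.
def example_score_fusion_similarity(clap_similarity: float, aux_similarity: float, alpha: float = FUSION_ALPHA) -> float:
    """Blend a CLAP cosine similarity with an auxiliary (MERT or pitch) cosine similarity."""
    return alpha * clap_similarity + (1.0 - alpha) * aux_similarity
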
class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.clap_model = None
        self.clap_processor = None
        self.mert_model = None
        self.mert_processor = None
        # Simple in-memory cache for pitch features keyed by audio hash.
        # Using a dict avoids the "unhashable type: numpy.ndarray" error we hit with functools.lru_cache.
        self.pitch_feature_cache: dict[int, np.ndarray] = {}

        try:
            # Use CPU for Hugging Face free tier
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load CLIP model with explicit cache directory
            logger.info("Loading CLIP model from HuggingFace...")
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)

            logger.info("Loading CLIP processor...")
            use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
            logger.info(f"Using fast processor: {use_fast}")
            self.clip_processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False,
                use_fast=use_fast
            )
            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise RuntimeError(f"CLIP model loading failed: {str(e)}")

    def _load_clap_model(self):
        """Load CLAP model on demand"""
        if not CLAP_AVAILABLE:
            raise RuntimeError("CLAP model not available - transformers version may not support CLAP")

        if self.clap_model is None:
            logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
            try:
                if CLAP_METHOD == "transformers":
                    logger.info("Loading CLAP model from HuggingFace...")
                    self.clap_model = ClapModel.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False
                    ).to(self.device)

                    logger.info("Loading CLAP processor...")
                    use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
                    self.clap_processor = ClapProcessor.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False,
                        use_fast=use_fast
                    )
                logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
            except Exception as e:
                logger.error(f"Failed to load CLAP model: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                raise RuntimeError(f"CLAP model loading failed: {str(e)}")

    def _load_mert_model(self):
        """Load MERT model on demand"""
        if not MERT_AVAILABLE:
            raise RuntimeError("MERT model not available - transformers version may not support MERT")

        if self.mert_model is None:
            logger.info(f"Loading MERT model on demand using {MERT_METHOD} method...")
            try:
                logger.info("Loading MERT model from HuggingFace...")
                self.mert_model = AutoModel.from_pretrained(
                    "m-a-p/MERT-v1-95M",
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    local_files_only=False
                ).to(self.device)

                # Guard against missing encoder (stale HF cache issue)
                if not hasattr(self.mert_model, "encoder"):
                    raise RuntimeError("MERT weights not loaded - clear HF cache and retry")

                logger.info("Loading MERT processor...")
                use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
                self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
                    "m-a-p/MERT-v1-95M",
                    cache_dir=cache_dir,
                    local_files_only=False,
                    use_fast=use_fast
                )
                logger.info(f"MERT model loaded successfully on {self.device} using {MERT_METHOD}")
            except Exception as e:
                logger.error(f"Failed to load MERT model: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                raise RuntimeError(f"MERT model loading failed: {str(e)}")

    def extract_pitch_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Extract pitch-aware features using numpy/scipy (avoiding all librosa caching issues) with a lightweight dict cache"""
        cache_key = hash(audio_array.tobytes())

        # Fast path: return cached result if we've already seen this audio chunk
        if cache_key in self.pitch_feature_cache:
            return self.pitch_feature_cache[cache_key]

        # Slow path: compute fresh features and store them
        features = self._extract_pitch_features_impl(audio_array, sample_rate)
        self.pitch_feature_cache[cache_key] = features
        return features
    def _extract_pitch_features_impl(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Implementation of pitch feature extraction (results are cached by the caller)"""
        try:
            # Use pure numpy/scipy implementations to avoid all librosa caching issues
            features = []

            # Extract basic audio features using numpy/scipy only
            try:
                # Basic spectral features using numpy FFT
                spectral_features = self._extract_spectral_features_numpy(audio_array, sample_rate)
                features.extend(spectral_features)
                logger.info("✓ Spectral features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Spectral feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(20))  # 20 spectral features

            try:
                # Basic temporal features
                temporal_features = self._extract_temporal_features_numpy(audio_array, sample_rate)
                features.extend(temporal_features)
                logger.info("✓ Temporal features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Temporal feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(15))  # 15 temporal features

            try:
                # Basic frequency domain features
                frequency_features = self._extract_frequency_features_numpy(audio_array, sample_rate)
                features.extend(frequency_features)
                logger.info("✓ Frequency features extracted (numpy)")
            except Exception as e:
                logger.warning(f"Frequency feature extraction failed: {e}, using zeros")
                features.extend(np.zeros(25))  # 25 frequency features

            # Simple tempo estimation (fallback)
            try:
                tempo = self._estimate_tempo_numpy(audio_array, sample_rate)
                features.append(tempo)
                logger.info(f"✓ Tempo estimated: {tempo:.1f} BPM")
            except Exception as e:
                logger.warning(f"Tempo estimation failed: {e}, using default")
                features.append(120.0)  # Default BPM

            # Convert to numpy array and check for NaN/inf values
            pitch_features = np.array(features, dtype=np.float32)

            # Replace any NaN or inf values with 0
            pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)

            # Ensure we have the expected 85 dimensions
            if len(pitch_features) < 85:
                # Pad with zeros if needed
                padding = np.zeros(85 - len(pitch_features), dtype=np.float32)
                pitch_features = np.concatenate([pitch_features, padding])
            elif len(pitch_features) > 85:
                # Truncate if too long
                pitch_features = pitch_features[:85]

            # L2 normalize
            norm = np.linalg.norm(pitch_features)
            if norm > 0:
                pitch_features = pitch_features / norm
            else:
                # If norm is 0, create a small non-zero vector
                pitch_features = np.ones(85, dtype=np.float32) * 0.001
                pitch_features = pitch_features / np.linalg.norm(pitch_features)

            # Final check for finite values
            pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)

            # The caller (extract_pitch_features) stores this result in the dict cache
            logger.info(f"Extracted pitch features: {len(pitch_features)} dimensions")
            return pitch_features

        except Exception as e:
            logger.error(f"Error extracting pitch features: {str(e)}")
            # Return normalized zero vector if extraction fails
            zero_features = np.ones(85, dtype=np.float32) * 0.001
            return zero_features / np.linalg.norm(zero_features)

    def _extract_spectral_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract spectral features using only numpy (no librosa)"""
        # Compute FFT
        fft = np.fft.fft(audio_array)
        magnitude = np.abs(fft)
        freqs = np.fft.fftfreq(len(audio_array), 1 / sample_rate)

        # Only use positive frequencies
        pos_freqs = freqs[:len(freqs) // 2]
        pos_magnitude = magnitude[:len(magnitude) // 2]

        # Spectral centroid
        spectral_centroid = np.sum(pos_freqs * pos_magnitude) / np.sum(pos_magnitude)

        # Spectral rolloff (95% of energy)
        cumsum = np.cumsum(pos_magnitude)
        rolloff_idx = np.where(cumsum >= 0.95 * cumsum[-1])[0]
        spectral_rolloff = pos_freqs[rolloff_idx[0]] if len(rolloff_idx) > 0 else 0

        # Spectral spread
        spectral_spread = np.sqrt(np.sum(((pos_freqs - spectral_centroid) ** 2) * pos_magnitude) / np.sum(pos_magnitude))

        # Zero crossing rate
        zero_crossings = np.where(np.diff(np.sign(audio_array)))[0]
        zcr = len(zero_crossings) / len(audio_array)

        # RMS energy
        rms = np.sqrt(np.mean(audio_array ** 2))

        # Basic spectral features
        features = [
            spectral_centroid / sample_rate,  # Normalize
            spectral_rolloff / sample_rate,   # Normalize
            spectral_spread / sample_rate,    # Normalize
            zcr,
            rms,
            np.max(pos_magnitude),
            np.mean(pos_magnitude),
            np.std(pos_magnitude),
            np.sum(pos_magnitude),
            np.var(pos_magnitude)
        ]

        # Add frequency band energies (10 bands)
        band_energies = []
        n_bands = 10
        for i in range(n_bands):
            start_idx = i * len(pos_magnitude) // n_bands
            end_idx = (i + 1) * len(pos_magnitude) // n_bands
            band_energy = np.sum(pos_magnitude[start_idx:end_idx])
            band_energies.append(band_energy)

        features.extend(band_energies)
        return features

    def _extract_temporal_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract temporal features using only numpy"""
        # Basic statistics
        mean_val = np.mean(audio_array)
        std_val = np.std(audio_array)
        skew_val = np.mean(((audio_array - mean_val) / std_val) ** 3)
        kurtosis_val = np.mean(((audio_array - mean_val) / std_val) ** 4)

        # Energy-based features
        energy = np.sum(audio_array ** 2)
        power = energy / len(audio_array)

        # Envelope features
        envelope = np.abs(audio_array)
        envelope_mean = np.mean(envelope)
        envelope_std = np.std(envelope)

        # Attack/decay characteristics
        envelope_diff = np.diff(envelope)
        attack_time = np.mean(envelope_diff[envelope_diff > 0])
        decay_time = np.mean(-envelope_diff[envelope_diff < 0])

        # Peak characteristics
        peaks = np.where(np.diff(np.sign(np.diff(audio_array))) < 0)[0]
        peak_density = len(peaks) / len(audio_array)

        # Dynamics
        dynamic_range = np.max(envelope) - np.min(envelope)

        features = [
            mean_val,
            std_val,
            skew_val,
            kurtosis_val,
            energy,
            power,
            envelope_mean,
            envelope_std,
            attack_time if np.isfinite(attack_time) else 0.0,
            decay_time if np.isfinite(decay_time) else 0.0,
            peak_density,
            dynamic_range,
            np.percentile(envelope, 25),
            np.percentile(envelope, 75),
            np.median(envelope)
        ]

        return features

    def _extract_frequency_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
        """Extract frequency domain features using only numpy"""
        # Short-time analysis
        hop_length = 512
        frame_length = 2048
        features = []

        # Process in overlapping windows
        n_frames = (len(audio_array) - frame_length) // hop_length + 1
        frame_features = []
        for i in range(0, min(n_frames, 50)):  # Limit to 50 frames for performance
            start = i * hop_length
            end = start + frame_length
            if end > len(audio_array):
                break

            frame = audio_array[start:end]

            # Apply window
            window = np.hanning(len(frame))
            windowed_frame = frame * window

            # FFT
            fft = np.fft.fft(windowed_frame)
            magnitude = np.abs(fft)

            # Frame-level features
            frame_energy = np.sum(magnitude ** 2)
            frame_centroid = np.sum(np.arange(len(magnitude)) * magnitude) / np.sum(magnitude)
            frame_features.append([frame_energy, frame_centroid])

        if frame_features:
            frame_features = np.array(frame_features)
            # Aggregate features across frames
            features.extend([
                np.mean(frame_features[:, 0]),           # Mean energy
                np.std(frame_features[:, 0]),            # Energy std
                np.mean(frame_features[:, 1]),           # Mean centroid
                np.std(frame_features[:, 1]),            # Centroid std
                np.max(frame_features[:, 0]),            # Max energy
                np.min(frame_features[:, 0]),            # Min energy
                np.median(frame_features[:, 0]),         # Median energy
                np.var(frame_features[:, 0]),            # Energy variance
                np.mean(np.diff(frame_features[:, 0])),  # Energy delta
                np.std(np.diff(frame_features[:, 0]))    # Energy delta std
            ])
        else:
            features.extend(np.zeros(10))

        # Add some basic harmonic features
        fft_full = np.fft.fft(audio_array)
        magnitude_full = np.abs(fft_full)

        # Find fundamental frequency (simple peak detection)
        freqs = np.fft.fftfreq(len(audio_array), 1 / sample_rate)
        pos_freqs = freqs[:len(freqs) // 2]
        pos_magnitude = magnitude_full[:len(magnitude_full) // 2]

        # Harmonic features
        peak_idx = np.argmax(pos_magnitude)
        fundamental_freq = pos_freqs[peak_idx]

        # Harmonic ratios (simple approximation)
        harmonic_features = []
        for harmonic in [2, 3, 4, 5]:
            target_freq = fundamental_freq * harmonic
            if target_freq < sample_rate / 2:
                target_idx = np.argmin(np.abs(pos_freqs - target_freq))
                harmonic_ratio = pos_magnitude[target_idx] / pos_magnitude[peak_idx]
                harmonic_features.append(harmonic_ratio)
            else:
                harmonic_features.append(0.0)

        features.extend(harmonic_features)

        # Add more frequency-based features
        features.extend([
            fundamental_freq / sample_rate,                                          # Normalized fundamental
            np.sum(pos_magnitude > np.mean(pos_magnitude)),                          # Number of significant peaks
            np.sum(pos_magnitude) / len(pos_magnitude),                              # Average magnitude
            np.std(pos_magnitude),                                                   # Magnitude std
            np.max(pos_magnitude) / np.mean(pos_magnitude),                          # Peak prominence
            np.sum(pos_magnitude[:len(pos_magnitude) // 4]),                         # Low freq energy
            np.sum(pos_magnitude[len(pos_magnitude) // 4:len(pos_magnitude) // 2]),  # Mid freq energy
            np.sum(pos_magnitude[len(pos_magnitude) // 2:3 * len(pos_magnitude) // 4]),  # High freq energy
            np.sum(pos_magnitude[3 * len(pos_magnitude) // 4:]),                     # Very high freq energy
            np.corrcoef(pos_magnitude[:-1], pos_magnitude[1:])[0, 1] if len(pos_magnitude) > 1 else 0.0,  # Spectral autocorr
            np.sum(np.diff(pos_magnitude) > 0) / len(pos_magnitude)                  # Spectral flux
        ])

        return features

    def _estimate_tempo_numpy(self, audio_array: np.ndarray, sample_rate: int) -> float:
        """Simple tempo estimation using numpy only"""
        try:
            # Simple envelope-based tempo detection
            hop_length = 512
            frame_length = 2048

            # Calculate envelope
            envelope = np.abs(audio_array)

            # Downsample envelope
            n_frames = (len(envelope) - frame_length) // hop_length + 1
            envelope_frames = []
            for i in range(0, min(n_frames, 1000)):  # Limit frames
                start = i * hop_length
                end = start + frame_length
                if end > len(envelope):
                    break
                frame_energy = np.mean(envelope[start:end])
                envelope_frames.append(frame_energy)

            if len(envelope_frames) < 10:
                return 120.0

            envelope_frames = np.array(envelope_frames)

            # Find peaks in envelope
            from scipy.signal import find_peaks
            peaks, _ = find_peaks(envelope_frames, height=np.mean(envelope_frames))

            if len(peaks) > 2:
                # Calculate intervals between peaks
                peak_intervals = np.diff(peaks) * hop_length / sample_rate

                # Filter reasonable intervals (0.2 to 2 seconds)
                valid_intervals = peak_intervals[(peak_intervals > 0.2) & (peak_intervals < 2.0)]

                if len(valid_intervals) > 0:
                    avg_interval = np.mean(valid_intervals)
                    tempo = 60.0 / avg_interval
                    # Constrain to reasonable range
                    tempo = max(60, min(200, tempo))
                    return tempo

            return 120.0
        except Exception as e:
            logger.warning(f"Numpy tempo estimation failed: {e}")
            return 120.0
    def extract_mert_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        """Extract MERT features using the Music Understanding Model"""
        try:
            # Load MERT model on demand
            self._load_mert_model()

            # Resample to 24kHz for MERT processing if needed
            if sample_rate != 24000:
                logger.info(f"Resampling from {sample_rate}Hz to 24kHz for MERT...")
                from scipy.signal import resample
                target_length = int(len(audio_array) * 24000 / sample_rate)
                audio_array_24k = resample(audio_array, target_length)
            else:
                audio_array_24k = audio_array

            # Critical: Convert to float32 for MERT (community checkpoints expect float32)
            audio_array_24k = audio_array_24k.astype(np.float32, copy=False)

            # Use multi-window strategy like CLAP to capture variation
            total_duration = len(audio_array_24k) / 24000
            logger.info(f"Processing MERT features for {len(audio_array_24k)} samples at 24kHz ({total_duration:.1f} seconds)")

            # Multi-window sampling approach (5-second windows as per MERT paper)
            hop_seconds = 5  # Move window every 5 seconds
            win_seconds = 5  # Each window is 5 seconds (MERT positional encodings trained on 5s)
            hop_samples = hop_seconds * 24000
            win_samples = win_seconds * 24000

            samples = []
            if total_duration <= win_seconds:
                # Short audio: use the entire thing
                logger.info("Short audio: using entire file for MERT")
                samples = [audio_array_24k]
            else:
                # Multi-window sampling: overlapping windows across entire track
                logger.info("Multi-window sampling for MERT: extracting overlapping windows")
                max_offset = int(total_duration - win_seconds) + 1
                for t in range(0, max_offset, hop_seconds):
                    start_sample = t * 24000
                    end_sample = start_sample + win_samples

                    # Ensure we don't go beyond the audio length
                    if end_sample <= len(audio_array_24k):
                        window = audio_array_24k[start_sample:end_sample]
                        samples.append(window)
                    else:
                        # Last window: take what's available
                        window = audio_array_24k[start_sample:]
                        if len(window) >= win_samples // 2:  # At least 2.5 seconds
                            # Pad to full window length
                            padded_window = np.pad(window, (0, win_samples - len(window)))
                            samples.append(padded_window)
                        break

                logger.info(f"Generated {len(samples)} MERT windows covering entire track")

            # Process each sample and collect embeddings
            embeddings = []
            for i, sample in enumerate(samples):
                sample_length = win_samples
                if len(sample) < sample_length:
                    # Pad short samples with zeros
                    sample = np.pad(sample, (0, sample_length - len(sample)))

                logger.info(f"Processing MERT sample {i+1}/{len(samples)}")

                # Process with MERT processor
                inputs = self.mert_processor(
                    sample,
                    sampling_rate=24000,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.mert_model(**inputs, output_hidden_states=True)

                # Get all hidden states (13 layers for MERT-v1-95M)
                all_hidden_states = outputs.hidden_states  # Tuple of 13 tensors, each [batch_size, time_steps, 768]

                # Time reduction: average across the time dimension -> [13, batch_size, 768]
                time_reduced_states = torch.stack([layer.mean(dim=1) for layer in all_hidden_states])

                # Take the mean across all layers to get a single 768-dim representation
                # This follows the typical approach for utterance-level representations
                window_embedding = time_reduced_states.mean(dim=0).squeeze()  # Shape: [768]
                embeddings.append(window_embedding.cpu().numpy().flatten())

            # Average the raw embeddings from all samples (like CLAP)
            final_embedding = np.mean(embeddings, axis=0)

            # Ensure finite values
            mert_features = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)

            # L2 normalize
            norm = np.linalg.norm(mert_features)
            if norm > 0:
                mert_features = mert_features / norm
            else:
                # If norm is 0, create a small non-zero vector
                mert_features = np.ones(768, dtype=np.float32) * 0.001
                mert_features = mert_features / np.linalg.norm(mert_features)
            # Final check for finite values
            mert_features = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)

            logger.info(f"Extracted MERT features: {len(mert_features)} dimensions (norm: {np.linalg.norm(mert_features):.6f})")
            return mert_features

        except Exception as e:
            logger.error(f"Error extracting MERT features: {str(e)}")
            # Return normalized zero vector if extraction fails
            zero_features = np.ones(768, dtype=np.float32) * 0.001
            return zero_features / np.linalg.norm(zero_features)

    def fuse_embeddings(self, clap_embedding: np.ndarray, pitch_features: np.ndarray = None, mert_features: np.ndarray = None) -> np.ndarray:
        """Fuse CLAP, pitch-aware, and/or MERT features based on fusion mode"""
        if FUSION_MODE == "VECTOR_CONCAT":
            # Normalize before and after concatenation (critical for proper similarity computation)
            embeddings_to_fuse = []

            # Step 1: Normalize each embedding individually (clap_u = normalize(clap))
            # Always include CLAP
            clap_u = clap_embedding / np.linalg.norm(clap_embedding)
            embeddings_to_fuse.append(clap_u)

            # Add MERT if available
            if mert_features is not None:
                mert_u = mert_features / np.linalg.norm(mert_features)
                embeddings_to_fuse.append(mert_u)

            # Add pitch features if available
            if pitch_features is not None:
                pitch_u = pitch_features / np.linalg.norm(pitch_features)
                embeddings_to_fuse.append(pitch_u)

            # Step 2: Concatenate normalized embeddings (fused = cat([clap_u, mert_u]))
            fused = np.concatenate(embeddings_to_fuse)

            # Step 3: Normalize the concatenated result (fused = normalize(fused))
            # Critical: Without this, similarity is inflated by √2 (0.83 → 0.93)
            fused_norm = np.linalg.norm(fused)
            if fused_norm > 0:
                fused = fused / fused_norm
            else:
                logger.warning("Zero norm in fused embedding - creating fallback vector")
                fused = np.ones_like(fused) * 0.001
                fused = fused / np.linalg.norm(fused)

            # Verify norm is 1.0 (assert abs(fused.norm() - 1.0) < 1e-6)
            final_norm = np.linalg.norm(fused)
            if abs(final_norm - 1.0) > 1e-6:
                logger.warning(f"Fused embedding norm is {final_norm:.8f}, not 1.0 - normalization issue!")

            # Ensure finite values for JSON serialization
            fused = np.nan_to_num(fused, nan=0.0, posinf=0.0, neginf=0.0)
            return fused
        else:
            # SCORE_FUSION: return embeddings separately for weighted similarity calculation
            result = {}
            result["clap"] = np.nan_to_num(clap_embedding, nan=0.0, posinf=0.0, neginf=0.0)
            if mert_features is not None:
                result["mert"] = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
            if pitch_features is not None:
                result["pitch"] = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
            return result
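
    # Hedged illustration (comment only; an assumption about downstream use): if a
    # similarity index takes raw inner products of vectors it assumes are unit-norm,
    # then for two tracks a and b fused from unit CLAP/MERT vectors,
    #     dot(concat(clap_a, mert_a), concat(clap_b, mert_b))
    #         = cos(clap_a, clap_b) + cos(mert_a, mert_b),
    # i.e. up to 2.0, while each concatenation actually has norm sqrt(2).
    # Re-normalizing in Step 3 makes the score the plain average of the per-model cosines:
    #     cos(fused_a, fused_b) = 0.5 * (cos(clap_a, clap_b) + cos(mert_a, mert_b)).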
    def is_supported_format(self, image_url: str) -> bool:
        """Check if image format is supported by PIL/CLIP"""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)

    def detect_image_format(self, content: bytes) -> str:
        """Detect actual image format from content"""
        try:
            # Check for AVIF signature
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            # Check for HEIC signature
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
                return 'HEIC'
            # Check for WebP
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            # Check for PNG
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            # Check for JPEG
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            # Check for GIF
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'

    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")

            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")

            response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
            response.raise_for_status()

            # Detect actual format from content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")

            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")

            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
                raise

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Resize image if too large to avoid memory issues
            max_size = 224  # CLIP's expected input size
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Try multiple processor configurations
            try:
                # Method 1: Standard CLIP processing
                inputs = self.clip_processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=True,
                    do_normalize=True
                )
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}, trying method 2...")
                try:
                    # Method 2: With padding
                    inputs = self.clip_processor(
                        images=image,
                        return_tensors="pt",
                        padding=True,
                        do_rescale=True,
                        do_normalize=True
                    )
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}, trying method 3...")
                    # Method 3: Manual preprocessing
                    inputs = self.clip_processor(
                        images=[image],
                        return_tensors="pt"
                    )

            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                image_features = self.clip_model.get_image_features(**inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Ensure safe values for JSON serialization
            embedding = image_features.cpu().numpy().flatten()
            embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
            return embedding.tolist()

        except HTTPException:
            # Preserve the 422 responses raised above instead of converting them to 500s
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)

            with torch.no_grad():
                text_features = self.clip_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Ensure safe values for JSON serialization
            embedding = text_features.cpu().numpy().flatten()
            embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
    def encode_audio(self, audio_url: str) -> list:
        try:
            logger.info(f"Processing audio: {audio_url}")

            # Load CLAP model on demand
            self._load_clap_model()

            # Pitch fusion is enabled if pitch-aware features are available
            if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
                logger.info("Pitch fusion enabled with librosa features")

            # Download audio file
            response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
            response.raise_for_status()

            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                tmp_file.write(response.content)
                tmp_path = tmp_file.name

            try:
                # Load audio using soundfile first, then resample with scipy if needed.
                # This avoids the caching issues with librosa.load
                logger.info("Loading audio with soundfile...")
                audio_array, original_sr = sf.read(tmp_path)

                # Convert to mono if needed
                if len(audio_array.shape) > 1:
                    logger.info("Converting stereo to mono")
                    audio_array = audio_array.mean(axis=1)
                else:
                    logger.info("Audio is already mono")

                # Resample to 48kHz for CLAP processing
                if original_sr != 48000:
                    logger.info(f"Resampling from {original_sr}Hz to 48kHz...")
                    # Use scipy for resampling to avoid librosa caching issues
                    from scipy.signal import resample
                    target_length = int(len(audio_array) * 48000 / original_sr)
                    audio_array_48k = resample(audio_array, target_length)
                else:
                    audio_array_48k = audio_array

                # Critical: Convert to float32 for CLAP (community checkpoints expect float32)
                audio_array_48k = audio_array_48k.astype(np.float32, copy=False)

                total_duration = len(audio_array_48k) / 48000
                logger.info(f"Audio loaded: {len(audio_array_48k)} samples at 48kHz ({total_duration:.1f} seconds)")

                # Check audio duration limit and truncate both arrays consistently
                if total_duration > MAX_AUDIO_DURATION_SEC:
                    logger.warning(f"Audio duration {total_duration:.1f}s exceeds limit {MAX_AUDIO_DURATION_SEC}s, truncating...")
                    max_samples_48k = MAX_AUDIO_DURATION_SEC * 48000
                    max_samples_orig = MAX_AUDIO_DURATION_SEC * original_sr
                    # Truncate both arrays to keep MERT and Pitch processing within limits
                    audio_array_48k = audio_array_48k[:max_samples_48k]
                    audio_array = audio_array[:max_samples_orig]
                    total_duration = MAX_AUDIO_DURATION_SEC
                    logger.info(f"Truncated both arrays to {total_duration:.1f} seconds (48kHz: {len(audio_array_48k)} samples, {original_sr}Hz: {len(audio_array)} samples)")

                # Process with CLAP (10s windows, 5s hops)
                clap_embedding = self._process_clap_embeddings(audio_array_48k, total_duration)

                # Initialize additional features
                pitch_features = None
                mert_features = None

                # Process with MERT features if enabled
                if ENABLE_MERT_FUSION and MERT_AVAILABLE:
                    try:
                        logger.info("Processing with MERT features for fusion...")
                        mert_features = self.extract_mert_features(audio_array, original_sr)
                    except Exception as e:
                        logger.error(f"MERT feature processing failed: {str(e)}")
                        mert_features = None

                # Process with pitch-aware features if enabled (can run alongside MERT)
                if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
                    try:
                        logger.info("Processing with pitch-aware features for fusion...")
                        pitch_features = self._process_pitch_features(audio_array, original_sr, total_duration)
                    except Exception as e:
                        logger.error(f"Pitch feature processing failed: {str(e)}")
                        pitch_features = None

                # Handle fusion based on what features are available
                if mert_features is not None or pitch_features is not None:
                    if FUSION_MODE == "VECTOR_CONCAT":
                        final_embedding = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
                        if mert_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + MERT 768)")
                        elif pitch_features is not None:
                            logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + Pitch 85)")
                        return final_embedding.tolist()
                    else:
                        # SCORE_FUSION: return embeddings separately
                        logger.info("Score fusion mode: returning separate embeddings")
                        result = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
                        result["fusion_mode"] = "SCORE_FUSION"
                        result["fusion_alpha"] = FUSION_ALPHA

                        # Convert to lists for JSON serialization
                        for key in result:
                            if isinstance(result[key], np.ndarray):
                                result[key] = result[key].tolist()
                        return result
                else:
                    # CLAP-only processing (backwards compatible)
                    logger.info("Using CLAP-only processing")
                    return clap_embedding.tolist()

            finally:
                # Clean up temp file
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)

        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")

    def _process_clap_embeddings(self, audio_array: np.ndarray, total_duration: float) -> np.ndarray:
        """Process audio with CLAP using 10s windows and 5s hops"""
        # Multi-window sampling approach for better discrimination
        hop_seconds = 5   # Move window every 5 seconds
        win_seconds = 10  # Each window is 10 seconds
        hop_samples = hop_seconds * 48000
        win_samples = win_seconds * 48000

        samples = []
        if total_duration <= win_seconds:
            # Short audio: use the entire thing
            logger.info("Short audio: using entire file for CLAP")
            samples = [audio_array]
        else:
            # Multi-window sampling: overlapping windows across entire track
            logger.info("Multi-window sampling for CLAP: extracting overlapping windows")
            max_offset = int(total_duration - win_seconds) + 1
            for t in range(0, max_offset, hop_seconds):
                start_sample = t * 48000
                end_sample = start_sample + win_samples

                # Ensure we don't go beyond the audio length
                if end_sample <= len(audio_array):
                    window = audio_array[start_sample:end_sample]
                    samples.append(window)
                else:
                    # Last window: take what's available
                    window = audio_array[start_sample:]
                    if len(window) >= win_samples // 2:  # At least 5 seconds
                        # Pad to full window length
                        padded_window = np.pad(window, (0, win_samples - len(window)))
                        samples.append(padded_window)
                    break

            logger.info(f"Generated {len(samples)} CLAP windows covering entire track")

        # Process each sample and collect embeddings
        embeddings = []
        for i, sample in enumerate(samples):
            sample_length = win_samples
            if len(sample) < sample_length:
                # Pad short samples with zeros
                sample = np.pad(sample, (0, sample_length - len(sample)))

            logger.info(f"Processing CLAP sample {i+1}/{len(samples)}")

            # Process with CLAP
            inputs = self.clap_processor(
                audios=sample,
                sampling_rate=48000,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                audio_features = self.clap_model.get_audio_features(**inputs)
                window_vec = audio_features.squeeze(0)  # 512-D, no L2
                embeddings.append(window_vec.cpu().numpy().flatten())

        # Average the raw embeddings from all samples
        final_embedding = np.mean(embeddings, axis=0)

        # Ensure finite values before normalization
        final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)

        # Ensure proper L2 normalization for cosine similarity
        norm = np.linalg.norm(final_embedding)
        if norm > 0:
            final_embedding = final_embedding / norm
        else:
            logger.warning("Zero norm CLAP embedding detected")
            # Create a small normalized vector instead of zero
            final_embedding = np.ones_like(final_embedding) * 0.001
            final_embedding = final_embedding / np.linalg.norm(final_embedding)

        # Final safety check for JSON serialization
        final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)

        # Verify normalization
        final_norm = np.linalg.norm(final_embedding)
        logger.info(f"Final CLAP embedding norm: {final_norm:.6f} (should be ~1.0)")

        return final_embedding
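
    # Hedged worked example (comment only): with 10 s windows and 5 s hops, a 23 s
    # track gives max_offset = int(23 - 10) + 1 = 14, so CLAP windows start at
    # t = 0, 5, and 10 seconds; a trailing window shorter than the full length but
    # at least half of it would be zero-padded before encoding.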
    def _process_pitch_features(self, audio_array: np.ndarray, original_sr: int, total_duration: float) -> np.ndarray:
        """Process audio with pitch-aware features using 5s windows and 2s hops"""
        # Pitch feature processing parameters
        hop_seconds = 2  # Move window every 2 seconds
        win_seconds = 5  # Each window is 5 seconds
        hop_samples = hop_seconds * original_sr
        win_samples = win_seconds * original_sr

        samples = []
        if total_duration <= win_seconds:
            # Short audio: use the entire thing
            logger.info("Short audio: using entire file for pitch features")
            samples = [audio_array]
        else:
            # Multi-window sampling: overlapping windows across entire track
            logger.info("Multi-window sampling for pitch features: extracting overlapping windows")
            max_offset = int(total_duration - win_seconds) + 1
            for t in range(0, max_offset, hop_seconds):
                start_sample = t * original_sr
                end_sample = start_sample + win_samples

                # Ensure we don't go beyond the audio length
                if end_sample <= len(audio_array):
                    window = audio_array[start_sample:end_sample]
                    samples.append(window)
                else:
                    # Last window: take what's available
                    window = audio_array[start_sample:]
                    if len(window) >= win_samples // 2:  # At least 2.5 seconds
                        # Pad to full window length
                        padded_window = np.pad(window, (0, win_samples - len(window)))
                        samples.append(padded_window)
                    break

            logger.info(f"Generated {len(samples)} pitch feature windows covering entire track")

        # Process each sample and collect features
        feature_vectors = []
        for i, sample in enumerate(samples):
            sample_length = win_samples
            if len(sample) < sample_length:
                # Pad short samples with zeros
                sample = np.pad(sample, (0, sample_length - len(sample)))

            logger.info(f"Processing pitch features sample {i+1}/{len(samples)}")

            # Extract pitch features
            pitch_features = self.extract_pitch_features(sample, original_sr)
            feature_vectors.append(pitch_features)

        # Average the raw feature vectors from all samples
        final_features = np.mean(feature_vectors, axis=0)

        # Ensure finite values before normalization
        final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)

        # Ensure proper L2 normalization for cosine similarity
        norm = np.linalg.norm(final_features)
        if norm > 0:
            final_features = final_features / norm
        else:
            logger.warning("Zero norm pitch features detected")
            # Create a small normalized vector instead of zero
            final_features = np.ones_like(final_features) * 0.001
            final_features = final_features / np.linalg.norm(final_features)

        # Final safety check for JSON serialization
        final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)

        # Verify normalization
        final_norm = np.linalg.norm(final_features)
        logger.info(f"Final pitch features norm: {final_norm:.6f} (should be ~1.0)")

        return final_features


# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    # For now, we'll let the app start but service calls will fail gracefully
    clip_service = None


class ImageRequest(BaseModel):
    image_url: str
class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str


@app.get("/")
async def root():
    return {
        "message": "CLIP Service API",
        "version": "1.0.0",
        "model": "clip-vit-large-patch14",
        "endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
        "status": "ready" if clip_service else "error"
    }


@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_image(request.image_url)
    safe_embedding = sanitize_for_json(embedding)
    return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}


@app.post("/encode/text")
async def encode_text(request: TextRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_text(request.text)
    safe_embedding = sanitize_for_json(embedding)
    return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}


@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLAP service not available")
    if not CLAP_AVAILABLE:
        raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")

    embedding = clip_service.encode_audio(request.audio_url)

    # Handle both single embedding and fusion mode results
    if isinstance(embedding, dict):
        # Score fusion mode - sanitize all embeddings
        safe_embedding = {}
        dimensions = {}
        for key, value in embedding.items():
            if key in ["clap", "pitch", "mert"] and isinstance(value, list):
                safe_embedding[key] = sanitize_for_json(value)
                dimensions[key] = len(safe_embedding[key])
            else:
                safe_embedding[key] = value
        return {"embedding": safe_embedding, "dimensions": dimensions}
    else:
        # Single embedding (CLAP-only or concatenated)
        safe_embedding = sanitize_for_json(embedding)
        return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}


@app.get("/health")
async def health_check():
    if not clip_service:
        return {
            "status": "unhealthy",
            "model": "clip-vit-large-patch14",
            "error": "Service failed to initialize"
        }

    health_info = {
        "status": "healthy",
        "models": {
            "clip": "clip-vit-large-patch14",
            "clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available",
            "mert": f"MERT-v1-95M (lazy loaded, method: {MERT_METHOD})" if MERT_AVAILABLE else "not available"
        },
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
    }

    # Add pitch-aware information
    if PITCH_AWARE_AVAILABLE:
        health_info["models"]["pitch_aware"] = f"librosa features ({PITCH_METHOD})"

    # Add fusion information
    fusion_enabled = ENABLE_PITCH_FUSION or ENABLE_MERT_FUSION
    if fusion_enabled:
        health_info["fusion"] = {
            "enabled": True,
            "mode": FUSION_MODE,
            "pitch_fusion_enabled": ENABLE_PITCH_FUSION,
            "mert_fusion_enabled": ENABLE_MERT_FUSION,
            "pitch_aware_available": PITCH_AWARE_AVAILABLE,
            "mert_available": MERT_AVAILABLE
        }
        if FUSION_MODE == "SCORE_FUSION":
            health_info["fusion"]["alpha"] = FUSION_ALPHA
    else:
        health_info["fusion"] = {
            "enabled": False,
            "mode": "CLAP_ONLY"
        }

    return health_info


if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
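
# Hedged usage sketch (illustration only; never invoked by the service itself).
# The base URL and media URLs below are assumptions for demonstration.
def _example_client_usage(base_url: str = "http://localhost:7860") -> None:
    """Post sample requests to /encode/text and /encode/audio and print the reported dimensions."""
    text_resp = requests.post(f"{base_url}/encode/text", json={"text": "lofi hip hop beat"}, timeout=60)
    text_resp.raise_for_status()
    print("text embedding dimensions:", text_resp.json()["dimensions"])

    audio_resp = requests.post(f"{base_url}/encode/audio", json={"audio_url": "https://example.com/track.wav"}, timeout=300)
    audio_resp.raise_for_status()
    # In SCORE_FUSION mode "dimensions" is a dict keyed by "clap"/"mert"/"pitch".
    print("audio embedding dimensions:", audio_resp.json()["dimensions"])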