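"""FastAPI embedding service: CLIP image/text embeddings plus optional CLAP, MERT,
and pitch-aware audio embeddings, fused via VECTOR_CONCAT or SCORE_FUSION."""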
import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel
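# Check CLAP availability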
try:
from transformers import ClapModel, ClapProcessor
CLAP_AVAILABLE = True
CLAP_METHOD = "transformers"
except ImportError as e1:
CLAP_AVAILABLE = False
CLAP_METHOD = None
# Check MERT availability
try:
from transformers import AutoModel, Wav2Vec2FeatureExtractor
MERT_AVAILABLE = True
MERT_METHOD = "transformers"
except ImportError as e2:
MERT_AVAILABLE = False
MERT_METHOD = None
import torch
import torchaudio
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf
import scipy.signal
# Set environment to disable librosa caching
os.environ['LIBROSA_CACHE_DIR'] = '/tmp'
os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'
# Disable librosa caching completely to avoid the __o_fold error
os.environ['LIBROSA_CACHE_LEVEL'] = '0'
os.environ['LIBROSA_CACHE_COMPRESS'] = '0'
# Check pitch-aware model availability
try:
# librosa is already imported above, so this check always succeeds; the pitch features below are implemented with numpy/scipy
import librosa
PITCH_AWARE_AVAILABLE = True
PITCH_METHOD = "librosa_chroma"
except ImportError:
PITCH_AWARE_AVAILABLE = False
PITCH_METHOD = None
# Fusion configuration
FUSION_MODE = os.environ.get('FUSION_MODE', 'VECTOR_CONCAT') # VECTOR_CONCAT or SCORE_FUSION
FUSION_ALPHA = float(os.environ.get('FUSION_ALPHA', '0.6')) # Alpha for score fusion
ENABLE_PITCH_FUSION = os.environ.get('ENABLE_PITCH_FUSION', 'false').lower() == 'true' and PITCH_AWARE_AVAILABLE
ENABLE_MERT_FUSION = os.environ.get('ENABLE_MERT_FUSION', 'false').lower() == 'true' and MERT_AVAILABLE
# Audio processing limits
MAX_AUDIO_DURATION_SEC = int(os.environ.get('MAX_AUDIO_DURATION_SEC', '600')) # 10 minutes default
# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="CLIP Service", version="1.0.0")
def sanitize_for_json(embedding):
"""Sanitize embedding for JSON serialization"""
if isinstance(embedding, np.ndarray):
# Ensure finite values and convert to list
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
elif isinstance(embedding, list):
# Check each element for finite values
return [float(x) if np.isfinite(x) else 0.0 for x in embedding]
else:
return embedding
# Log CLAP, MERT, and pitch-aware availability after logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")
logger.info(f"MERT availability: {MERT_AVAILABLE}, method: {MERT_METHOD}")
logger.info(f"Pitch-aware availability: {PITCH_AWARE_AVAILABLE}, method: {PITCH_METHOD}")
logger.info(f"Pitch fusion enabled: {ENABLE_PITCH_FUSION}")
logger.info(f"MERT fusion enabled: {ENABLE_MERT_FUSION}")
logger.info(f"Fusion mode: {FUSION_MODE}")
if FUSION_MODE == 'SCORE_FUSION':
logger.info(f"Fusion alpha: {FUSION_ALPHA}")
class CLIPService:
def __init__(self):
logger.info("Loading CLIP model...")
self.clap_model = None
self.clap_processor = None
self.mert_model = None
self.mert_processor = None
# Simple in-memory cache for pitch features keyed by audio hash
# Using a dict avoids the "unhashable type: numpy.ndarray" error we hit with functools.lru_cache
self.pitch_feature_cache: dict[int, np.ndarray] = {}
try:
# Prefer GPU when available; otherwise fall back to CPU (e.g. on the Hugging Face free tier)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Load CLIP model with explicit cache directory
logger.info("Loading CLIP model from HuggingFace...")
self.clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLIP processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
logger.info(f"Using fast processor: {use_fast}")
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"CLIP model loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load CLIP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLIP model loading failed: {str(e)}")
def _load_clap_model(self):
"""Load CLAP model on demand"""
if not CLAP_AVAILABLE:
raise RuntimeError("CLAP model not available - transformers version may not support CLAP")
if self.clap_model is None:
logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
try:
if CLAP_METHOD == "transformers":
logger.info("Loading CLAP model from HuggingFace...")
self.clap_model = ClapModel.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLAP processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
self.clap_processor = ClapProcessor.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
except Exception as e:
logger.error(f"Failed to load CLAP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLAP model loading failed: {str(e)}")
def _load_mert_model(self):
"""Load MERT model on demand"""
if not MERT_AVAILABLE:
raise RuntimeError("MERT model not available - transformers version may not support MERT")
if self.mert_model is None:
logger.info(f"Loading MERT model on demand using {MERT_METHOD} method...")
try:
logger.info("Loading MERT model from HuggingFace...")
self.mert_model = AutoModel.from_pretrained(
"m-a-p/MERT-v1-95M",
trust_remote_code=True,
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
# Guard against missing encoder (stale HF cache issue)
if not hasattr(self.mert_model, "encoder"):
raise RuntimeError("MERT weights not loaded - clear HF cache and retry")
logger.info("Loading MERT processor...")
use_fast = os.environ.get('USE_FAST_PROCESSOR', 'true').lower() == 'true'
self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
"m-a-p/MERT-v1-95M",
cache_dir=cache_dir,
local_files_only=False,
use_fast=use_fast
)
logger.info(f"MERT model loaded successfully on {self.device} using {MERT_METHOD}")
except Exception as e:
logger.error(f"Failed to load MERT model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"MERT model loading failed: {str(e)}")
def extract_pitch_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Extract pitch-aware features using numpy/scipy (avoiding all librosa caching issues) with a lightweight dict cache"""
cache_key = hash(audio_array.tobytes())
# Fast path: return cached result if we've already seen this audio chunk
if cache_key in self.pitch_feature_cache:
return self.pitch_feature_cache[cache_key]
# Slow path: compute fresh features and store them
features = self._extract_pitch_features_impl(audio_array, sample_rate)
self.pitch_feature_cache[cache_key] = features
return features
def _extract_pitch_features_impl(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Implementation of pitch feature extraction (cached)"""
try:
# Use pure numpy/scipy implementations to avoid all librosa caching issues
features = []
# Extract basic audio features using numpy/scipy only
try:
# Basic spectral features using numpy FFT
spectral_features = self._extract_spectral_features_numpy(audio_array, sample_rate)
features.extend(spectral_features)
logger.info("✓ Spectral features extracted (numpy)")
except Exception as e:
logger.warning(f"Spectral feature extraction failed: {e}, using zeros")
features.extend(np.zeros(20)) # 20 spectral features
try:
# Basic temporal features
temporal_features = self._extract_temporal_features_numpy(audio_array, sample_rate)
features.extend(temporal_features)
logger.info("✓ Temporal features extracted (numpy)")
except Exception as e:
logger.warning(f"Temporal feature extraction failed: {e}, using zeros")
features.extend(np.zeros(15)) # 15 temporal features
try:
# Basic frequency domain features
frequency_features = self._extract_frequency_features_numpy(audio_array, sample_rate)
features.extend(frequency_features)
logger.info("✓ Frequency features extracted (numpy)")
except Exception as e:
logger.warning(f"Frequency feature extraction failed: {e}, using zeros")
features.extend(np.zeros(25)) # 25 frequency features
# Simple tempo estimation (fallback)
try:
tempo = self._estimate_tempo_numpy(audio_array, sample_rate)
features.append(tempo)
logger.info(f"✓ Tempo estimated: {tempo:.1f} BPM")
except Exception as e:
logger.warning(f"Tempo estimation failed: {e}, using default")
features.append(120.0) # Default BPM
# Convert to numpy array and check for NaN/inf values
pitch_features = np.array(features, dtype=np.float32)
# Replace any NaN or inf values with 0
pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure we have the expected 85 dimensions
if len(pitch_features) < 85:
# Pad with zeros if needed
padding = np.zeros(85 - len(pitch_features), dtype=np.float32)
pitch_features = np.concatenate([pitch_features, padding])
elif len(pitch_features) > 85:
# Truncate if too long
pitch_features = pitch_features[:85]
# L2 normalize
norm = np.linalg.norm(pitch_features)
if norm > 0:
pitch_features = pitch_features / norm
else:
# If norm is 0, create a small non-zero vector
pitch_features = np.ones(85, dtype=np.float32) * 0.001
pitch_features = pitch_features / np.linalg.norm(pitch_features)
# Final check for finite values
pitch_features = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
# Result is cached by extract_pitch_features via the dict keyed by the audio hash
logger.info(f"Extracted pitch features: {len(pitch_features)} dimensions")
return pitch_features
except Exception as e:
logger.error(f"Error extracting pitch features: {str(e)}")
# Return normalized zero vector if extraction fails
zero_features = np.ones(85, dtype=np.float32) * 0.001
return zero_features / np.linalg.norm(zero_features)
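# Nominal feature budget from the helpers below: 20 spectral + 15 temporal + 25 frequency + 1 tempo
# = 61 values, which the code above zero-pads to the fixed 85-dimension pitch vector.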
def _extract_spectral_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract spectral features using only numpy (no librosa)"""
# Compute FFT
fft = np.fft.fft(audio_array)
magnitude = np.abs(fft)
freqs = np.fft.fftfreq(len(audio_array), 1/sample_rate)
# Only use positive frequencies
pos_freqs = freqs[:len(freqs)//2]
pos_magnitude = magnitude[:len(magnitude)//2]
# Spectral centroid
spectral_centroid = np.sum(pos_freqs * pos_magnitude) / np.sum(pos_magnitude)
# Spectral rolloff (95% of energy)
cumsum = np.cumsum(pos_magnitude)
rolloff_idx = np.where(cumsum >= 0.95 * cumsum[-1])[0]
spectral_rolloff = pos_freqs[rolloff_idx[0]] if len(rolloff_idx) > 0 else 0
# Spectral spread
spectral_spread = np.sqrt(np.sum(((pos_freqs - spectral_centroid) ** 2) * pos_magnitude) / np.sum(pos_magnitude))
# Zero crossing rate
zero_crossings = np.where(np.diff(np.sign(audio_array)))[0]
zcr = len(zero_crossings) / len(audio_array)
# RMS energy
rms = np.sqrt(np.mean(audio_array ** 2))
# Basic spectral features
features = [
spectral_centroid / sample_rate, # Normalize
spectral_rolloff / sample_rate, # Normalize
spectral_spread / sample_rate, # Normalize
zcr,
rms,
np.max(pos_magnitude),
np.mean(pos_magnitude),
np.std(pos_magnitude),
np.sum(pos_magnitude),
np.var(pos_magnitude)
]
# Add frequency band energies (10 bands)
band_energies = []
n_bands = 10
for i in range(n_bands):
start_idx = i * len(pos_magnitude) // n_bands
end_idx = (i + 1) * len(pos_magnitude) // n_bands
band_energy = np.sum(pos_magnitude[start_idx:end_idx])
band_energies.append(band_energy)
features.extend(band_energies)
return features
def _extract_temporal_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract temporal features using only numpy"""
# Basic statistics
mean_val = np.mean(audio_array)
std_val = np.std(audio_array)
skew_val = np.mean(((audio_array - mean_val) / std_val) ** 3)
kurtosis_val = np.mean(((audio_array - mean_val) / std_val) ** 4)
# Energy-based features
energy = np.sum(audio_array ** 2)
power = energy / len(audio_array)
# Envelope features
envelope = np.abs(audio_array)
envelope_mean = np.mean(envelope)
envelope_std = np.std(envelope)
# Attack/decay characteristics
envelope_diff = np.diff(envelope)
attack_time = np.mean(envelope_diff[envelope_diff > 0])
decay_time = np.mean(-envelope_diff[envelope_diff < 0])
# Peak characteristics
peaks = np.where(np.diff(np.sign(np.diff(audio_array))) < 0)[0]
peak_density = len(peaks) / len(audio_array)
# Dynamics
dynamic_range = np.max(envelope) - np.min(envelope)
features = [
mean_val,
std_val,
skew_val,
kurtosis_val,
energy,
power,
envelope_mean,
envelope_std,
attack_time if np.isfinite(attack_time) else 0.0,
decay_time if np.isfinite(decay_time) else 0.0,
peak_density,
dynamic_range,
np.percentile(envelope, 25),
np.percentile(envelope, 75),
np.median(envelope)
]
return features
def _extract_frequency_features_numpy(self, audio_array: np.ndarray, sample_rate: int) -> list:
"""Extract frequency domain features using only numpy"""
# Short-time analysis
hop_length = 512
frame_length = 2048
features = []
# Process in overlapping windows
n_frames = (len(audio_array) - frame_length) // hop_length + 1
frame_features = []
for i in range(0, min(n_frames, 50)): # Limit to 50 frames for performance
start = i * hop_length
end = start + frame_length
if end > len(audio_array):
break
frame = audio_array[start:end]
# Apply window
window = np.hanning(len(frame))
windowed_frame = frame * window
# FFT
fft = np.fft.fft(windowed_frame)
magnitude = np.abs(fft)
# Frame-level features
frame_energy = np.sum(magnitude ** 2)
frame_centroid = np.sum(np.arange(len(magnitude)) * magnitude) / np.sum(magnitude)
frame_features.append([frame_energy, frame_centroid])
if frame_features:
frame_features = np.array(frame_features)
# Aggregate features across frames
features.extend([
np.mean(frame_features[:, 0]), # Mean energy
np.std(frame_features[:, 0]), # Energy std
np.mean(frame_features[:, 1]), # Mean centroid
np.std(frame_features[:, 1]), # Centroid std
np.max(frame_features[:, 0]), # Max energy
np.min(frame_features[:, 0]), # Min energy
np.median(frame_features[:, 0]), # Median energy
np.var(frame_features[:, 0]), # Energy variance
np.mean(np.diff(frame_features[:, 0])), # Energy delta
np.std(np.diff(frame_features[:, 0])) # Energy delta std
])
else:
features.extend(np.zeros(10))
# Add some basic harmonic features
fft_full = np.fft.fft(audio_array)
magnitude_full = np.abs(fft_full)
# Find fundamental frequency (simple peak detection)
freqs = np.fft.fftfreq(len(audio_array), 1/sample_rate)
pos_freqs = freqs[:len(freqs)//2]
pos_magnitude = magnitude_full[:len(magnitude_full)//2]
# Harmonic features
peak_idx = np.argmax(pos_magnitude)
fundamental_freq = pos_freqs[peak_idx]
# Harmonic ratios (simple approximation)
harmonic_features = []
for harmonic in [2, 3, 4, 5]:
target_freq = fundamental_freq * harmonic
if target_freq < sample_rate / 2:
target_idx = np.argmin(np.abs(pos_freqs - target_freq))
harmonic_ratio = pos_magnitude[target_idx] / pos_magnitude[peak_idx]
harmonic_features.append(harmonic_ratio)
else:
harmonic_features.append(0.0)
features.extend(harmonic_features)
# Add more frequency-based features
features.extend([
fundamental_freq / sample_rate, # Normalized fundamental
np.sum(pos_magnitude > np.mean(pos_magnitude)), # Number of significant peaks
np.sum(pos_magnitude) / len(pos_magnitude), # Average magnitude
np.std(pos_magnitude), # Magnitude std
np.max(pos_magnitude) / np.mean(pos_magnitude), # Peak prominence
np.sum(pos_magnitude[:len(pos_magnitude)//4]), # Low freq energy
np.sum(pos_magnitude[len(pos_magnitude)//4:len(pos_magnitude)//2]), # Mid freq energy
np.sum(pos_magnitude[len(pos_magnitude)//2:3*len(pos_magnitude)//4]), # High freq energy
np.sum(pos_magnitude[3*len(pos_magnitude)//4:]), # Very high freq energy
np.corrcoef(pos_magnitude[:-1], pos_magnitude[1:])[0,1] if len(pos_magnitude) > 1 else 0.0, # Spectral autocorr
np.sum(np.diff(pos_magnitude) > 0) / len(pos_magnitude) # Fraction of rising spectral bins (rough spectral-flux proxy)
])
return features
def _estimate_tempo_numpy(self, audio_array: np.ndarray, sample_rate: int) -> float:
"""Simple tempo estimation using numpy only"""
try:
# Simple envelope-based tempo detection
hop_length = 512
frame_length = 2048
# Calculate envelope
envelope = np.abs(audio_array)
# Downsample envelope
n_frames = (len(envelope) - frame_length) // hop_length + 1
envelope_frames = []
for i in range(0, min(n_frames, 1000)): # Limit frames
start = i * hop_length
end = start + frame_length
if end > len(envelope):
break
frame_energy = np.mean(envelope[start:end])
envelope_frames.append(frame_energy)
if len(envelope_frames) < 10:
return 120.0
envelope_frames = np.array(envelope_frames)
# Find peaks in envelope
from scipy.signal import find_peaks
peaks, _ = find_peaks(envelope_frames, height=np.mean(envelope_frames))
if len(peaks) > 2:
# Calculate intervals between peaks
peak_intervals = np.diff(peaks) * hop_length / sample_rate
# Filter reasonable intervals (0.2 to 2 seconds)
valid_intervals = peak_intervals[(peak_intervals > 0.2) & (peak_intervals < 2.0)]
if len(valid_intervals) > 0:
avg_interval = np.mean(valid_intervals)
tempo = 60.0 / avg_interval
# Constrain to reasonable range
tempo = max(60, min(200, tempo))
return tempo
return 120.0
except Exception as e:
logger.warning(f"Numpy tempo estimation failed: {e}")
return 120.0
def extract_mert_features(self, audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
"""Extract MERT features using the Music Understanding Model"""
try:
# Load MERT model on demand
self._load_mert_model()
# Resample to 24kHz for MERT processing if needed
if sample_rate != 24000:
logger.info(f"Resampling from {sample_rate}Hz to 24kHz for MERT...")
from scipy.signal import resample
target_length = int(len(audio_array) * 24000 / sample_rate)
audio_array_24k = resample(audio_array, target_length)
else:
audio_array_24k = audio_array
# Critical: Convert to float32 for MERT (community checkpoints expect float32)
audio_array_24k = audio_array_24k.astype(np.float32, copy=False)
# Use multi-window strategy like CLAP to capture variation
total_duration = len(audio_array_24k) / 24000
logger.info(f"Processing MERT features for {len(audio_array_24k)} samples at 24kHz ({total_duration:.1f} seconds)")
# Multi-window sampling approach (5-second windows as per MERT paper)
hop_seconds = 5 # Move window every 5 seconds
win_seconds = 5 # Each window is 5 seconds (MERT positional encodings trained on 5s)
hop_samples = hop_seconds * 24000
win_samples = win_seconds * 24000
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for MERT")
samples = [audio_array_24k]
else:
# Multi-window sampling: consecutive 5 s windows across the entire track (hop == window length, so no overlap)
logger.info("Multi-window sampling for MERT: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * 24000
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array_24k):
window = audio_array_24k[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array_24k[start_sample:]
if len(window) >= win_samples // 2: # At least 2.5 seconds (half a window)
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} MERT windows covering entire track")
# Process each sample and collect embeddings
embeddings = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing MERT sample {i+1}/{len(samples)}")
# Process with MERT processor
inputs = self.mert_processor(
sample,
sampling_rate=24000,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.mert_model(**inputs, output_hidden_states=True)
# Get all hidden states (13 tensors for MERT-v1-95M: the embedding output plus 12 transformer layers)
all_hidden_states = outputs.hidden_states # Shape: [13, batch_size, time_steps, 768]
# Time reduction: average across time dimension
# Shape: [13, batch_size, 768]
time_reduced_states = torch.stack([layer.mean(dim=1) for layer in all_hidden_states])
# Take the mean across all layers to get a single 768-dim representation
# This follows the typical approach for utterance-level representations
window_embedding = time_reduced_states.mean(dim=0).squeeze() # Shape: [768]
embeddings.append(window_embedding.cpu().numpy().flatten())
# Average the raw embeddings from all samples (like CLAP)
final_embedding = np.mean(embeddings, axis=0)
# Ensure finite values
mert_features = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# L2 normalize
norm = np.linalg.norm(mert_features)
if norm > 0:
mert_features = mert_features / norm
else:
# If norm is 0, create a small non-zero vector
mert_features = np.ones(768, dtype=np.float32) * 0.001
mert_features = mert_features / np.linalg.norm(mert_features)
# Final check for finite values
mert_features = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
logger.info(f"Extracted MERT features: {len(mert_features)} dimensions (norm: {np.linalg.norm(mert_features):.6f})")
return mert_features
except Exception as e:
logger.error(f"Error extracting MERT features: {str(e)}")
# Return normalized zero vector if extraction fails
zero_features = np.ones(768, dtype=np.float32) * 0.001
return zero_features / np.linalg.norm(zero_features)
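# Window-count illustration: with hop == window (5 s / 5 s) the MERT windows are back-to-back,
# so a 180 s track produces 36 windows at offsets 0, 5, ..., 175 s.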
def fuse_embeddings(self, clap_embedding: np.ndarray, pitch_features: np.ndarray = None, mert_features: np.ndarray = None) -> np.ndarray:
"""Fuse CLAP, pitch-aware, and/or MERT features based on fusion mode"""
if FUSION_MODE == "VECTOR_CONCAT":
# Normalize before and after concatenation (critical for proper similarity computation)
embeddings_to_fuse = []
# Step 1: Normalize each embedding individually (clap_u = normalize(clap))
# Always include CLAP
clap_u = clap_embedding / np.linalg.norm(clap_embedding)
embeddings_to_fuse.append(clap_u)
# Add MERT if available
if mert_features is not None:
mert_u = mert_features / np.linalg.norm(mert_features)
embeddings_to_fuse.append(mert_u)
# Add pitch features if available
if pitch_features is not None:
pitch_u = pitch_features / np.linalg.norm(pitch_features)
embeddings_to_fuse.append(pitch_u)
# Step 2: Concatenate normalized embeddings (fused = cat([clap_u, mert_u]))
fused = np.concatenate(embeddings_to_fuse)
# Step 3: Normalize the concatenated result (fused = normalize(fused))
# Critical: Without this, similarity is inflated by √2 (0.83 → 0.93)
fused_norm = np.linalg.norm(fused)
if fused_norm > 0:
fused = fused / fused_norm
else:
logger.warning("Zero norm in fused embedding - creating fallback vector")
fused = np.ones_like(fused) * 0.001
fused = fused / np.linalg.norm(fused)
# Verify norm is 1.0 (assert abs(fused.norm() - 1.0) < 1e-6)
final_norm = np.linalg.norm(fused)
if abs(final_norm - 1.0) > 1e-6:
logger.warning(f"Fused embedding norm is {final_norm:.8f}, not 1.0 - normalization issue!")
# Ensure finite values for JSON serialization
fused = np.nan_to_num(fused, nan=0.0, posinf=0.0, neginf=0.0)
return fused
else:
# SCORE_FUSION: return embeddings separately for weighted similarity calculation
result = {}
result["clap"] = np.nan_to_num(clap_embedding, nan=0.0, posinf=0.0, neginf=0.0)
if mert_features is not None:
result["mert"] = np.nan_to_num(mert_features, nan=0.0, posinf=0.0, neginf=0.0)
if pitch_features is not None:
result["pitch"] = np.nan_to_num(pitch_features, nan=0.0, posinf=0.0, neginf=0.0)
return result
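# Dimensionality of the VECTOR_CONCAT result: 512 (CLAP) + 768 (MERT, if enabled) + 85 (pitch, if enabled),
# e.g. 1280 for CLAP+MERT or 1365 with all three. Because each part is unit-normalised before concatenation
# and the concatenation is renormalised, the dot product of two fused vectors equals the average of the
# per-model cosine similarities (assuming both vectors were built from the same set of parts).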
def is_supported_format(self, image_url: str) -> bool:
"""Check if image format is supported by PIL/CLIP"""
unsupported_extensions = ['.avif', '.heic', '.heif']
url_lower = image_url.lower()
return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
def detect_image_format(self, content: bytes) -> str:
"""Detect actual image format from content"""
try:
# Check for AVIF signature
if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
return 'AVIF'
# Check for HEIC signature
elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
return 'HEIC'
# Check for WebP
elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
return 'WebP'
# Check for PNG
elif content.startswith(b'\x89PNG\r\n\x1a\n'):
return 'PNG'
# Check for JPEG
elif content.startswith(b'\xff\xd8\xff'):
return 'JPEG'
# Check for GIF
elif content.startswith((b'GIF87a', b'GIF89a')):
return 'GIF'
else:
return 'Unknown'
except Exception:
return 'Unknown'
def encode_image(self, image_url: str) -> list:
try:
logger.info(f"Processing image: {image_url}")
# Quick URL-based format check first
if not self.is_supported_format(image_url):
logger.warning(f"Unsupported format detected from URL: {image_url}")
raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
response.raise_for_status()
# Detect actual format from content
image_format = self.detect_image_format(response.content)
logger.info(f"Detected image format: {image_format}")
if image_format in ['AVIF', 'HEIC']:
logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
try:
image = Image.open(io.BytesIO(response.content))
except Exception as e:
logger.error(f"PIL cannot open image {image_url}: {str(e)}")
if "cannot identify image file" in str(e).lower():
raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
raise
if image.mode != 'RGB':
logger.info(f"Converting image from {image.mode} to RGB")
image = image.convert('RGB')
# Resize image if too large to avoid memory issues
max_size = 224 # CLIP's expected input size
if max(image.size) > max_size:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
# Try multiple processor configurations
try:
# Method 1: Standard CLIP processing
inputs = self.clip_processor(
images=image,
return_tensors="pt",
do_rescale=True,
do_normalize=True
)
except Exception as e1:
logger.warning(f"Method 1 failed: {e1}, trying method 2...")
try:
# Method 2: With padding
inputs = self.clip_processor(
images=image,
return_tensors="pt",
padding=True,
do_rescale=True,
do_normalize=True
)
except Exception as e2:
logger.warning(f"Method 2 failed: {e2}, trying method 3...")
# Method 3: Manual preprocessing
inputs = self.clip_processor(
images=[image],
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
image_features = self.clip_model.get_image_features(**inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Ensure safe values for JSON serialization
embedding = image_features.cpu().numpy().flatten()
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
except Exception as e:
logger.error(f"Error encoding image {image_url}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
def encode_text(self, text: str) -> list:
try:
logger.info(f"Processing text: {text[:50]}...")
inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)
with torch.no_grad():
text_features = self.clip_model.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Ensure safe values for JSON serialization
embedding = text_features.cpu().numpy().flatten()
embedding = np.nan_to_num(embedding, nan=0.0, posinf=0.0, neginf=0.0)
return embedding.tolist()
except Exception as e:
logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
def encode_audio(self, audio_url: str) -> list:
try:
logger.info(f"Processing audio: {audio_url}")
# Load CLAP model on demand
self._load_clap_model()
# Pitch fusion is enabled if pitch-aware features are available
if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
logger.info("Pitch fusion enabled with librosa features")
# Download audio file
response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
response.raise_for_status()
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
tmp_file.write(response.content)
tmp_path = tmp_file.name
try:
# Load audio using soundfile first, then resample with librosa if needed
# This avoids the caching issues with librosa.load
logger.info("Loading audio with soundfile...")
audio_array, original_sr = sf.read(tmp_path)
# Convert to mono if needed
if len(audio_array.shape) > 1:
logger.info("Converting stereo to mono")
audio_array = audio_array.mean(axis=1)
else:
logger.info("Audio is already mono")
# Resample to 48kHz for CLAP processing
if original_sr != 48000:
logger.info(f"Resampling from {original_sr}Hz to 48kHz...")
# Use scipy for resampling to avoid librosa caching issues
from scipy.signal import resample
target_length = int(len(audio_array) * 48000 / original_sr)
audio_array_48k = resample(audio_array, target_length)
else:
audio_array_48k = audio_array
# Critical: Convert to float32 for CLAP (community checkpoints expect float32)
audio_array_48k = audio_array_48k.astype(np.float32, copy=False)
total_duration = len(audio_array_48k) / 48000
logger.info(f"Audio loaded: {len(audio_array_48k)} samples at 48kHz ({total_duration:.1f} seconds)")
# Check audio duration limit and truncate both arrays consistently
if total_duration > MAX_AUDIO_DURATION_SEC:
logger.warning(f"Audio duration {total_duration:.1f}s exceeds limit {MAX_AUDIO_DURATION_SEC}s, truncating...")
max_samples_48k = MAX_AUDIO_DURATION_SEC * 48000
max_samples_orig = MAX_AUDIO_DURATION_SEC * original_sr
# Truncate both arrays to keep MERT and Pitch processing within limits
audio_array_48k = audio_array_48k[:max_samples_48k]
audio_array = audio_array[:max_samples_orig]
total_duration = MAX_AUDIO_DURATION_SEC
logger.info(f"Truncated both arrays to {total_duration:.1f} seconds (48kHz: {len(audio_array_48k)} samples, {original_sr}Hz: {len(audio_array)} samples)")
# Process with CLAP (10s windows, 5s hops)
clap_embedding = self._process_clap_embeddings(audio_array_48k, total_duration)
# Initialize additional features
pitch_features = None
mert_features = None
# Process with MERT features if enabled
if ENABLE_MERT_FUSION and MERT_AVAILABLE:
try:
logger.info("Processing with MERT features for fusion...")
mert_features = self.extract_mert_features(audio_array, original_sr)
except Exception as e:
logger.error(f"MERT feature processing failed: {str(e)}")
mert_features = None
# Process with pitch-aware features if enabled (can run alongside MERT)
if ENABLE_PITCH_FUSION and PITCH_AWARE_AVAILABLE:
try:
logger.info("Processing with pitch-aware features for fusion...")
pitch_features = self._process_pitch_features(audio_array, original_sr, total_duration)
except Exception as e:
logger.error(f"Pitch feature processing failed: {str(e)}")
pitch_features = None
# Handle fusion based on what features are available
if mert_features is not None or pitch_features is not None:
if FUSION_MODE == "VECTOR_CONCAT":
final_embedding = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
if mert_features is not None:
logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + MERT 768)")
elif pitch_features is not None:
logger.info(f"Fused embedding dimensions: {len(final_embedding)} (CLAP 512 + Pitch 85)")
return final_embedding.tolist()
else:
# SCORE_FUSION: return embeddings separately
logger.info("Score fusion mode: returning separate embeddings")
result = self.fuse_embeddings(clap_embedding, pitch_features, mert_features)
result["fusion_mode"] = "SCORE_FUSION"
result["fusion_alpha"] = FUSION_ALPHA
# Convert to lists for JSON serialization
for key in result:
if isinstance(result[key], np.ndarray):
result[key] = result[key].tolist()
return result
else:
# CLAP-only processing (backwards compatible)
logger.info("Using CLAP-only processing")
return clap_embedding.tolist()
finally:
# Clean up temp file
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
logger.error(f"Error encoding audio {audio_url}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")
def _process_clap_embeddings(self, audio_array: np.ndarray, total_duration: float) -> np.ndarray:
"""Process audio with CLAP using 10s windows and 5s hops"""
# Multi-window sampling approach for better discrimination
hop_seconds = 5 # Move window every 5 seconds
win_seconds = 10 # Each window is 10 seconds
hop_samples = hop_seconds * 48000
win_samples = win_seconds * 48000
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for CLAP")
samples = [audio_array]
else:
# Multi-window sampling: overlapping windows across entire track
logger.info("Multi-window sampling for CLAP: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * 48000
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array):
window = audio_array[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array[start_sample:]
if len(window) >= win_samples // 2: # At least 5 seconds
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} CLAP windows covering entire track")
# Process each sample and collect embeddings
embeddings = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing CLAP sample {i+1}/{len(samples)}")
# Process with CLAP
inputs = self.clap_processor(
audios=sample,
sampling_rate=48000,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
audio_features = self.clap_model.get_audio_features(**inputs)
window_vec = audio_features.squeeze(0) # 512-D, no L2
embeddings.append(window_vec.cpu().numpy().flatten())
# Average the raw embeddings from all samples
final_embedding = np.mean(embeddings, axis=0)
# Ensure finite values before normalization
final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure proper L2 normalization for cosine similarity
norm = np.linalg.norm(final_embedding)
if norm > 0:
final_embedding = final_embedding / norm
else:
logger.warning("Zero norm CLAP embedding detected")
# Create a small normalized vector instead of zero
final_embedding = np.ones_like(final_embedding) * 0.001
final_embedding = final_embedding / np.linalg.norm(final_embedding)
# Final safety check for JSON serialization
final_embedding = np.nan_to_num(final_embedding, nan=0.0, posinf=0.0, neginf=0.0)
# Verify normalization
final_norm = np.linalg.norm(final_embedding)
logger.info(f"Final CLAP embedding norm: {final_norm:.6f} (should be ~1.0)")
return final_embedding
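# Window-count illustration: a 180 s track yields CLAP windows at offsets 0, 5, ..., 170 s,
# i.e. 35 ten-second windows, each overlapping its neighbour by 5 s.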
def _process_pitch_features(self, audio_array: np.ndarray, original_sr: int, total_duration: float) -> np.ndarray:
"""Process audio with pitch-aware features using 5s windows and 2s hops"""
# Pitch feature processing parameters
hop_seconds = 2 # Move window every 2 seconds
win_seconds = 5 # Each window is 5 seconds
hop_samples = hop_seconds * original_sr
win_samples = win_seconds * original_sr
samples = []
if total_duration <= win_seconds:
# Short audio: use the entire thing
logger.info("Short audio: using entire file for pitch features")
samples = [audio_array]
else:
# Multi-window sampling: overlapping windows across entire track
logger.info("Multi-window sampling for pitch features: extracting overlapping windows")
max_offset = int(total_duration - win_seconds) + 1
for t in range(0, max_offset, hop_seconds):
start_sample = t * original_sr
end_sample = start_sample + win_samples
# Ensure we don't go beyond the audio length
if end_sample <= len(audio_array):
window = audio_array[start_sample:end_sample]
samples.append(window)
else:
# Last window: take what's available
window = audio_array[start_sample:]
if len(window) >= win_samples // 2: # At least 2.5 seconds
# Pad to full window length
padded_window = np.pad(window, (0, win_samples - len(window)))
samples.append(padded_window)
break
logger.info(f"Generated {len(samples)} pitch feature windows covering entire track")
# Process each sample and collect features
feature_vectors = []
for i, sample in enumerate(samples):
sample_length = win_samples
if len(sample) < sample_length:
# Pad short samples with zeros
sample = np.pad(sample, (0, sample_length - len(sample)))
logger.info(f"Processing pitch features sample {i+1}/{len(samples)}")
# Extract pitch features
pitch_features = self.extract_pitch_features(sample, original_sr)
feature_vectors.append(pitch_features)
# Average the raw feature vectors from all samples
final_features = np.mean(feature_vectors, axis=0)
# Ensure finite values before normalization
final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure proper L2 normalization for cosine similarity
norm = np.linalg.norm(final_features)
if norm > 0:
final_features = final_features / norm
else:
logger.warning("Zero norm pitch features detected")
# Create a small normalized vector instead of zero
final_features = np.ones_like(final_features) * 0.001
final_features = final_features / np.linalg.norm(final_features)
# Final safety check for JSON serialization
final_features = np.nan_to_num(final_features, nan=0.0, posinf=0.0, neginf=0.0)
# Verify normalization
final_norm = np.linalg.norm(final_features)
logger.info(f"Final pitch features norm: {final_norm:.6f} (should be ~1.0)")
return final_features
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
clip_service = CLIPService()
logger.info("CLIP service initialized successfully!")
except Exception as e:
logger.error(f"Failed to initialize CLIP service: {str(e)}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
# For now, we'll let the app start but service calls will fail gracefully
clip_service = None
class ImageRequest(BaseModel):
image_url: str
class TextRequest(BaseModel):
text: str
class AudioRequest(BaseModel):
audio_url: str
@app.get("/")
async def root():
return {
"message": "CLIP Service API",
"version": "1.0.0",
"model": "clip-vit-large-patch14",
"endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
"status": "ready" if clip_service else "error"
}
@app.post("/encode/image")
async def encode_image(request: ImageRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_image(request.image_url)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.post("/encode/text")
async def encode_text(request: TextRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_text(request.text)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLAP service not available")
if not CLAP_AVAILABLE:
raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
embedding = clip_service.encode_audio(request.audio_url)
# Handle both single embedding and fusion mode results
if isinstance(embedding, dict):
# Score fusion mode - sanitize all embeddings
safe_embedding = {}
dimensions = {}
for key, value in embedding.items():
if key in ["clap", "pitch", "mert"] and isinstance(value, list):
safe_embedding[key] = sanitize_for_json(value)
dimensions[key] = len(safe_embedding[key])
else:
safe_embedding[key] = value
return {"embedding": safe_embedding, "dimensions": dimensions}
else:
# Single embedding (CLAP-only or concatenated)
safe_embedding = sanitize_for_json(embedding)
return {"embedding": safe_embedding, "dimensions": len(safe_embedding)}
@app.get("/health")
async def health_check():
if not clip_service:
return {
"status": "unhealthy",
"model": "clip-vit-large-patch14",
"error": "Service failed to initialize"
}
health_info = {
"status": "healthy",
"models": {
"clip": "clip-vit-large-patch14",
"clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available",
"mert": f"MERT-v1-95M (lazy loaded, method: {MERT_METHOD})" if MERT_AVAILABLE else "not available"
},
"device": clip_service.device,
"service": "ready",
"cache_dir": cache_dir
}
# Add pitch-aware information
if PITCH_AWARE_AVAILABLE:
health_info["models"]["pitch_aware"] = f"librosa features ({PITCH_METHOD})"
# Add fusion information
fusion_enabled = ENABLE_PITCH_FUSION or ENABLE_MERT_FUSION
if fusion_enabled:
health_info["fusion"] = {
"enabled": True,
"mode": FUSION_MODE,
"pitch_fusion_enabled": ENABLE_PITCH_FUSION,
"mert_fusion_enabled": ENABLE_MERT_FUSION,
"pitch_aware_available": PITCH_AWARE_AVAILABLE,
"mert_available": MERT_AVAILABLE
}
if FUSION_MODE == "SCORE_FUSION":
health_info["fusion"]["alpha"] = FUSION_ALPHA
else:
health_info["fusion"] = {
"enabled": False,
"mode": "CLAP_ONLY"
}
return health_info
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860)) # Hugging Face uses port 7860
uvicorn.run(app, host="0.0.0.0", port=port)
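# Example usage (assuming the service is reachable locally on the default port 7860; URLs are placeholders):
#   curl -X POST http://localhost:7860/encode/text  -H "Content-Type: application/json" -d '{"text": "ambient piano"}'
#   curl -X POST http://localhost:7860/encode/image -H "Content-Type: application/json" -d '{"image_url": "https://example.com/cover.jpg"}'
#   curl -X POST http://localhost:7860/encode/audio -H "Content-Type: application/json" -d '{"audio_url": "https://example.com/track.wav"}'
# Each endpoint returns {"embedding": [...], "dimensions": ...}; in SCORE_FUSION mode /encode/audio returns
# per-model embeddings (keyed "clap", "mert", "pitch") plus "fusion_mode" and "fusion_alpha" under "embedding",
# with per-model dimension counts under "dimensions".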