from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel, ClapModel, ClapProcessor
import torch
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")


class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        logger.info("CLIP model loaded successfully")

        logger.info("Loading CLAP model for audio...")
        self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
        logger.info("CLAP model loaded successfully")

    def encode_image(self, image_url: str) -> list:
        try:
            # Enhanced headers for better compatibility with R2 CDN
            headers = {
                'User-Agent': 'CLIP-Service/1.0 (Image-Embedding-Service)',
                'Accept': 'image/*',
                'Cache-Control': 'no-cache'
            }

            logger.info(f"Fetching image from URL: {image_url}")
            response = requests.get(image_url, timeout=30, headers=headers)
            response.raise_for_status()

            # Log successful fetch with content info
            content_type = response.headers.get('content-type', 'unknown')
            content_length = len(response.content)
            logger.info(f"Successfully fetched image: {content_type}, {content_length} bytes")

            image = Image.open(io.BytesIO(response.content))
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Log image dimensions for debugging
            logger.info(f"Processing image: {image.size[0]}x{image.size[1]}")

            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            embedding = image_features.numpy().flatten().tolist()
            logger.info(f"Generated embedding with {len(embedding)} dimensions")
            return embedding

        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching image {image_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to image URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Image not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Image service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch image: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            inputs = self.processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")

    def encode_audio(self, audio_url: str) -> list:
        try:
            # Enhanced headers for audio files with MIME whitelist
            headers = {
                'User-Agent': 'CLAP-Service/1.0 (Audio-Embedding-Service)',
                'Accept': 'audio/mpeg, audio/wav, audio/mp4, audio/ogg, audio/flac',
                'Cache-Control': 'no-cache'
            }

            logger.info(f"Fetching audio from URL: {audio_url}")

            # Increase timeout for large files, but add streaming response
            response = requests.get(audio_url, timeout=60, headers=headers, stream=True)
            response.raise_for_status()

            # Check content type before processing
            content_type = response.headers.get('content-type', 'unknown')
            if not content_type.startswith('audio/'):
                raise ValueError(f"Invalid content type: {content_type}. Expected audio/*")

            # Check file size before downloading (100MB limit)
            content_length = response.headers.get('content-length')
            if content_length and int(content_length) > 100 * 1024 * 1024:
                raise ValueError(f"Audio file too large: {content_length} bytes. Maximum is 100MB")

            # Stream content to BytesIO with size limit
            audio_data = io.BytesIO()
            total_size = 0
            max_size = 100 * 1024 * 1024  # 100MB

            for chunk in response.iter_content(chunk_size=8192):
                total_size += len(chunk)
                if total_size > max_size:
                    raise ValueError("Audio file too large during download")
                audio_data.write(chunk)

            audio_data.seek(0)
            logger.info(f"Successfully fetched audio: {content_type}, {total_size} bytes")

            # Load audio with duration limit (10 minutes = 600 seconds)
            MAX_DURATION = 600  # 10 minutes

            try:
                # First, get duration without loading full audio
                duration = librosa.get_duration(path=audio_data)
                audio_data.seek(0)  # Reset stream

                if duration > MAX_DURATION:
                    raise ValueError(f"Audio duration ({duration:.1f}s) exceeds maximum allowed ({MAX_DURATION}s)")

                logger.info(f"Audio duration: {duration:.1f} seconds")

                # Load only first 30 seconds for embedding (CLAP works well with shorter clips)
                # This reduces memory usage significantly
                duration_limit = min(30.0, duration)

                # Load audio with librosa (48kHz is CLAP's expected sample rate)
                waveform, sample_rate = librosa.load(
                    audio_data,
                    sr=48000,
                    mono=True,
                    duration=duration_limit,
                    offset=0.0
                )

                logger.info(f"Processing audio: {len(waveform)} samples at {sample_rate}Hz ({duration_limit:.1f}s)")

            except Exception as e:
                logger.error(f"Error loading audio file: {str(e)}")
                raise ValueError(f"Failed to load audio file: {str(e)}")

            # Process audio through CLAP
            inputs = self.clap_processor(audios=waveform, return_tensors="pt", sampling_rate=48000)
            with torch.no_grad():
                audio_features = self.clap_model.get_audio_features(**inputs)
                # Normalize the features
                audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

            embedding = audio_features.numpy().flatten().tolist()
            logger.info(f"Generated audio embedding with {len(embedding)} dimensions")
            return embedding

        except ValueError as e:
            # Handle validation errors (file too large, wrong format, etc.)
            logger.error(f"Validation error for audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e))
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching audio {audio_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to audio URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Audio not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Audio service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch audio: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")

    def encode_text_for_audio(self, text: str) -> list:
        """Encode text for cross-modal audio search"""
        try:
            inputs = self.clap_processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.clap_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text for audio '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text for audio: {str(e)}")


# Initialize service
clip_service = CLIPService()


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str


@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding}


@app.post("/encode/text")
async def encode_text(request: TextRequest):
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding}


@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
    """Encode audio file to CLAP embedding vector"""
    embedding = clip_service.encode_audio(request.audio_url)
    return {"embedding": embedding}


@app.post("/encode/text-audio")
async def encode_text_for_audio(request: TextRequest):
    """Encode text for audio similarity search"""
    embedding = clip_service.encode_text_for_audio(request.text)
    return {"embedding": embedding}


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "clip-vit-large-patch14"}


@app.post("/validate/image")
async def validate_image_url(request: ImageRequest):
    """Validate that an image URL is accessible without processing it"""
    try:
        headers = {
            'User-Agent': 'CLIP-Service/1.0 (Image-Validation)',
            'Accept': 'image/*'
        }
        response = requests.head(request.image_url, timeout=10, headers=headers)
        response.raise_for_status()

        content_type = response.headers.get('content-type', 'unknown')
        content_length = response.headers.get('content-length', 'unknown')

        return {
            "accessible": True,
            "content_type": content_type,
            "content_length": content_length,
            "url": request.image_url
        }
    except Exception as e:
        return {
            "accessible": False,
            "error": str(e),
            "url": request.image_url
        }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)