import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel

try:
    from transformers import ClapModel, ClapProcessor
    CLAP_AVAILABLE = True
    CLAP_METHOD = "transformers"
except ImportError:
    CLAP_AVAILABLE = False
    CLAP_METHOD = None

import torch
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf

# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")

# Log CLAP availability after the logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")


class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.clap_model = None
        self.clap_processor = None
        try:
            # Use GPU if available; CPU is the norm on the Hugging Face free tier
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load CLIP model with explicit cache directory
            logger.info("Loading CLIP model from HuggingFace...")
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)

            logger.info("Loading CLIP processor...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            )

            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise RuntimeError(f"CLIP model loading failed: {str(e)}")

    def _load_clap_model(self):
        """Load the CLAP model on demand."""
        if not CLAP_AVAILABLE:
            raise RuntimeError("CLAP model not available - transformers version may not support CLAP")

        if self.clap_model is None:
            logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
            try:
                if CLAP_METHOD == "transformers":
                    logger.info("Loading CLAP model from HuggingFace...")
                    self.clap_model = ClapModel.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False
                    ).to(self.device)

                    logger.info("Loading CLAP processor...")
                    self.clap_processor = ClapProcessor.from_pretrained(
                        "laion/clap-htsat-unfused",
                        cache_dir=cache_dir,
                        local_files_only=False
                    )

                logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
            except Exception as e:
                logger.error(f"Failed to load CLAP model: {str(e)}")
                logger.error(f"Error type: {type(e).__name__}")
                raise RuntimeError(f"CLAP model loading failed: {str(e)}")

    def is_supported_format(self, image_url: str) -> bool:
        """Check whether the image format is supported by PIL/CLIP, based on the URL extension."""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)

    def detect_image_format(self, content: bytes) -> str:
        """Detect the actual image format from the file content."""
        try:
            # Check for AVIF signature
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            # Check for HEIC signature
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
                return 'HEIC'
            # Check for WebP
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            # Check for PNG
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            # Check for JPEG
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            # Check for GIF
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'

    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")

            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")

            response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
            response.raise_for_status()

            # Detect the actual format from the downloaded content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")

            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")

            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
                raise

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Downscale large images to avoid memory issues (224 is CLIP's expected input size)
            max_size = 224
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Try multiple processor configurations
            try:
                # Method 1: Standard CLIP processing
                inputs = self.clip_processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=True,
                    do_normalize=True
                )
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}, trying method 2...")
                try:
                    # Method 2: With padding
                    inputs = self.clip_processor(
                        images=image,
                        return_tensors="pt",
                        padding=True,
                        do_rescale=True,
                        do_normalize=True
                    )
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}, trying method 3...")
                    # Method 3: Minimal call with default preprocessing
                    inputs = self.clip_processor(
                        images=[image],
                        return_tensors="pt"
                    )

            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                image_features = self.clip_model.get_image_features(**inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            return image_features.cpu().numpy().flatten().tolist()
        except HTTPException:
            # Re-raise intentional 4xx responses instead of converting them to 500s
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)

            with torch.no_grad():
                text_features = self.clip_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            return text_features.cpu().numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")

    def encode_audio(self, audio_url: str) -> list:
        try:
            logger.info(f"Processing audio: {audio_url}")

            # Load CLAP model on demand
            self._load_clap_model()

            # Download audio file
            response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
            response.raise_for_status()

            # Save to a temporary file
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                tmp_file.write(response.content)
                tmp_path = tmp_file.name

            try:
                # Load audio with librosa; CLAP expects a 48kHz sampling rate
                audio_array, sample_rate = librosa.load(tmp_path, sr=48000, mono=True)

                # Truncate audio that is too long (max 30 seconds for CLAP)
                max_length = 30 * 48000  # 30 seconds at 48kHz
                if len(audio_array) > max_length:
                    audio_array = audio_array[:max_length]

                # Process with CLAP using the transformers processor
                inputs = self.clap_processor(
                    audios=audio_array,
                    sampling_rate=48000,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    audio_features = self.clap_model.get_audio_features(**inputs)
                    audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

                return audio_features.cpu().numpy().flatten().tolist()
            finally:
                # Clean up the temp file
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")


# Initialize the service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    # Let the app start anyway; endpoint calls will fail gracefully with 503
    clip_service = None


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str


@app.get("/")
async def root():
    return {
        "message": "CLIP Service API",
        "version": "1.0.0",
        "model": "clip-vit-large-patch14",
        "endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
        "status": "ready" if clip_service else "error"
    }


@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding, "dimensions": len(embedding)}


@app.post("/encode/text")
async def encode_text(request: TextRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding, "dimensions": len(embedding)}


@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLAP service not available")
    if not CLAP_AVAILABLE:
        raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
    embedding = clip_service.encode_audio(request.audio_url)
    return {"embedding": embedding, "dimensions": len(embedding)}


@app.get("/health")
async def health_check():
    if not clip_service:
        return {
            "status": "unhealthy",
            "model": "clip-vit-large-patch14",
            "error": "Service failed to initialize"
        }
    return {
        "status": "healthy",
        "models": {
            "clip": "clip-vit-large-patch14",
            "clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available"
        },
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
    }


if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
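
# ---------------------------------------------------------------------------
# Example usage (a sketch, kept as comments so this module stays importable):
# it assumes the service is running locally on port 7860 and that the
# `requests` library is installed; adjust the base URL and example inputs to
# your deployment.
#
#   import requests
#
#   base = "http://localhost:7860"
#   print(requests.get(f"{base}/health").json())
#
#   resp = requests.post(f"{base}/encode/text", json={"text": "a photo of a cat"})
#   body = resp.json()
#   print(body["dimensions"], body["embedding"][:5])  # embedding size and a preview
#
#   resp = requests.post(f"{base}/encode/image",
#                        json={"image_url": "https://example.com/cat.jpg"})
#   print(resp.status_code)  # 422 for unsupported formats such as AVIF/HEIC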