import os
import io
import logging

import numpy as np
import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel

# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")


class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        try:
            # Prefer GPU when available; fall back to CPU (e.g., on the
            # Hugging Face free tier)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load model with explicit cache directory
            self.model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)

            self.processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            )

            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def is_supported_format(self, image_url: str) -> bool:
        """Quick URL-based check for formats PIL cannot open without plugins."""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)

    def detect_image_format(self, content: bytes) -> str:
        """Detect the actual image format from magic bytes in the content."""
        try:
            # ISO BMFF containers (AVIF/HEIC) begin with a box size,
            # followed by an 'ftyp' box naming the brand
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (
                b'heic' in content[:32] or b'heix' in content[:32]
            ):
                return 'HEIC'
            # WebP: 'RIFF' at offset 0, 'WEBP' at offset 8
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'

    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")

            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(
                    status_code=422,
                    detail="Unsupported image format (AVIF/HEIC not supported)"
                )

            response = requests.get(
                image_url,
                timeout=30,
                headers={'User-Agent': 'CLIP-Service/1.0'}
            )
            response.raise_for_status()

            # Detect actual format from content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")

            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(
                    status_code=422,
                    detail=f"Unsupported image format: {image_format}"
                )

            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(
                        status_code=422,
                        detail="Unsupported or corrupted image format"
                    )
                raise

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Downscale large images to avoid memory issues; the processor
            # resizes to CLIP's expected 224x224 input anyway
            max_size = 224
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Try multiple processor configurations
            try:
                # Method 1: Standard CLIP processing
                inputs = self.processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=True,
                    do_normalize=True
                )
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}, trying method 2...")
                try:
                    # Method 2: With padding
                    inputs = self.processor(
                        images=image,
                        return_tensors="pt",
                        padding=True,
                        do_rescale=True,
                        do_normalize=True
                    )
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}, trying method 3...")
                    # Method 3: Default preprocessing on a single-image batch
                    inputs = self.processor(images=[image], return_tensors="pt")

            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize so downstream cosine similarity is a dot product
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            return image_features.cpu().numpy().flatten().tolist()

        except HTTPException:
            # Re-raise intentional 4xx responses instead of wrapping them as 500
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)

            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            return text_features.cpu().numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
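
# Illustrative helper, not used by the endpoints below: because encode_image
# and encode_text both return L2-normalized vectors, the cosine similarity
# between an image embedding and a text embedding reduces to a plain dot
# product. A minimal sketch; the function name is ours, not part of any API.
def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity of two L2-normalized embeddings (a dot product)."""
    return float(np.dot(np.asarray(a, dtype=np.float32), np.asarray(b, dtype=np.float32)))
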
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    clip_service = None


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


@app.get("/")
async def root():
    return {
        "message": "CLIP Service API",
        "version": "1.0.0",
        "model": "clip-vit-large-patch14",
        "endpoints": ["/encode/image", "/encode/text", "/health"],
        "status": "ready" if clip_service else "error"
    }


@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding, "dimensions": len(embedding)}


@app.post("/encode/text")
async def encode_text(request: TextRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding, "dimensions": len(embedding)}


@app.get("/health")
async def health_check():
    if not clip_service:
        return {
            "status": "unhealthy",
            "model": "clip-vit-large-patch14",
            "error": "Service failed to initialize"
        }
    return {
        "status": "healthy",
        "model": "clip-vit-large-patch14",
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
    }


if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces defaults to port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
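
# Example requests (a sketch; host and port assume the local defaults above,
# and the image URL is a placeholder):
#
#   curl -X POST http://localhost:7860/encode/image \
#        -H "Content-Type: application/json" \
#        -d '{"image_url": "https://example.com/photo.jpg"}'
#
#   curl -X POST http://localhost:7860/encode/text \
#        -H "Content-Type: application/json" \
#        -d '{"text": "a photo of a cat"}'
#
# Both return {"embedding": [...], "dimensions": 768} for this model, since
# clip-vit-large-patch14 projects both modalities into a 768-d space.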