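"""CLIP embedding microservice.

A FastAPI app that wraps openai/clip-vit-large-patch14 and exposes
/encode/image, /encode/text and /health endpoints, intended to run on a
Hugging Face Space (default port 7860).
"""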
import io
import logging
import os

import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel
# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")
class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        try:
            # CPU on the Hugging Face free tier; CUDA if available
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load model with explicit cache directory
            self.model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            )
            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}")
    def is_supported_format(self, image_url: str) -> bool:
        """Check if the image format is supported by PIL/CLIP, based on the URL."""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
    def detect_image_format(self, content: bytes) -> str:
        """Detect the actual image format from the file's magic bytes."""
        try:
            # Check for AVIF signature
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            # Check for HEIC signature
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
                return 'HEIC'
            # Check for WebP
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            # Check for PNG
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            # Check for JPEG
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            # Check for GIF
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'
    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")

            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")

            response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
            response.raise_for_status()

            # Detect actual format from content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")

            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")

            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
                raise

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Resize image if too large to avoid memory issues
            max_size = 224  # CLIP's expected input size
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Try multiple processor configurations
            try:
                # Method 1: Standard CLIP processing
                inputs = self.processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=True,
                    do_normalize=True
                )
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}, trying method 2...")
                try:
                    # Method 2: With padding
                    inputs = self.processor(
                        images=image,
                        return_tensors="pt",
                        padding=True,
                        do_rescale=True,
                        do_normalize=True
                    )
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}, trying method 3...")
                    # Method 3: Default preprocessing with a batched list input
                    inputs = self.processor(
                        images=[image],
                        return_tensors="pt"
                    )

            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize so cosine similarity reduces to a dot product downstream
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            return image_features.cpu().numpy().flatten().tolist()
        except HTTPException:
            # Preserve the intended status code (e.g. 422) instead of wrapping it in a 500
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    clip_service = None


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str
@app.get("/")
async def root():
    return {
        "message": "CLIP Service API",
        "version": "1.0.0",
        "model": "clip-vit-large-patch14",
        "endpoints": ["/encode/image", "/encode/text", "/health"],
        "status": "ready" if clip_service else "error"
    }
@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding, "dimensions": len(embedding)}
@app.post("/encode/text")
async def encode_text(request: TextRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLIP service not available")
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding, "dimensions": len(embedding)}
@app.get("/health")
async def health_check():
    if not clip_service:
        return {
            "status": "unhealthy",
            "model": "clip-vit-large-patch14",
            "error": "Service failed to initialize"
        }
    return {
        "status": "healthy",
        "model": "clip-vit-large-patch14",
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
    }
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces expects port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
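# Example client usage (a sketch, assuming the service is running locally on
# port 7860; the image URL below is a hypothetical, reachable JPEG):
#
#   import requests
#   base = "http://localhost:7860"
#   print(requests.get(f"{base}/health").json())
#   text_emb = requests.post(f"{base}/encode/text",
#                            json={"text": "a photo of a cat"}).json()["embedding"]
#   image_emb = requests.post(f"{base}/encode/image",
#                             json={"image_url": "https://example.com/cat.jpg"}).json()["embedding"]
#   # Both vectors are L2-normalized, so cosine similarity is just a dot product:
#   similarity = sum(a * b for a, b in zip(text_emb, image_emb))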