# strandtest / app-simple.py
import os
import tempfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import requests
import numpy as np
import io
import logging
# Set up cache directories
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="CLIP Service", version="1.0.0")
class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        try:
            # Use CPU for Hugging Face free tier
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            # Load model with explicit cache directory
            self.model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            ).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-large-patch14",
                cache_dir=cache_dir,
                local_files_only=False
            )
            logger.info(f"CLIP model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {str(e)}")
            raise RuntimeError(f"Model loading failed: {str(e)}")
    def is_supported_format(self, image_url: str) -> bool:
        """Check if image format is supported by PIL/CLIP"""
        unsupported_extensions = ['.avif', '.heic', '.heif']
        url_lower = image_url.lower()
        return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
    def detect_image_format(self, content: bytes) -> str:
        """Detect actual image format from content (magic bytes)"""
        try:
            # Check for AVIF signature
            if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
                return 'AVIF'
            # Check for HEIC signature
            elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
                return 'HEIC'
            # Check for WebP
            elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
                return 'WebP'
            # Check for PNG
            elif content.startswith(b'\x89PNG\r\n\x1a\n'):
                return 'PNG'
            # Check for JPEG
            elif content.startswith(b'\xff\xd8\xff'):
                return 'JPEG'
            # Check for GIF
            elif content.startswith((b'GIF87a', b'GIF89a')):
                return 'GIF'
            else:
                return 'Unknown'
        except Exception:
            return 'Unknown'
    def encode_image(self, image_url: str) -> list:
        try:
            logger.info(f"Processing image: {image_url}")
            # Quick URL-based format check first
            if not self.is_supported_format(image_url):
                logger.warning(f"Unsupported format detected from URL: {image_url}")
                raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
            response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
            response.raise_for_status()
            # Detect actual format from content
            image_format = self.detect_image_format(response.content)
            logger.info(f"Detected image format: {image_format}")
            if image_format in ['AVIF', 'HEIC']:
                logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
                raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
            try:
                image = Image.open(io.BytesIO(response.content))
            except Exception as e:
                logger.error(f"PIL cannot open image {image_url}: {str(e)}")
                if "cannot identify image file" in str(e).lower():
                    raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
                raise
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')
            # Resize image if too large to avoid memory issues
            max_size = 224  # CLIP's expected input size
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            # Try multiple processor configurations
            try:
                # Method 1: Standard CLIP processing
                inputs = self.processor(
                    images=image,
                    return_tensors="pt",
                    do_rescale=True,
                    do_normalize=True
                )
            except Exception as e1:
                logger.warning(f"Method 1 failed: {e1}, trying method 2...")
                try:
                    # Method 2: With padding
                    inputs = self.processor(
                        images=image,
                        return_tensors="pt",
                        padding=True,
                        do_rescale=True,
                        do_normalize=True
                    )
                except Exception as e2:
                    logger.warning(f"Method 2 failed: {e2}, trying method 3...")
                    # Method 3: Minimal processing (defaults only)
                    inputs = self.processor(
                        images=[image],
                        return_tensors="pt"
                    )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().flatten().tolist()
        except HTTPException:
            # Re-raise intentional client errors (e.g. 422 for unsupported formats)
            # so they are not converted into 500s by the handler below
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
    def encode_text(self, text: str) -> list:
        try:
            logger.info(f"Processing text: {text[:50]}...")
            inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
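# Note (illustrative sketch, not part of the original service): both encode_image and
# encode_text return L2-normalized vectors, so the plain dot product of two embeddings
# already equals their cosine similarity. A hypothetical helper a caller could use:
#
#     def cosine_similarity(a: list, b: list) -> float:
#         """Cosine similarity of two already-normalized CLIP embeddings."""
#         return float(sum(x * y for x, y in zip(a, b)))
#
#     # e.g. cosine_similarity(clip_service.encode_image(url),
#     #                        clip_service.encode_text("a photo of a cat"))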
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
    clip_service = CLIPService()
    logger.info("CLIP service initialized successfully!")
except Exception as e:
    logger.error(f"Failed to initialize CLIP service: {str(e)}")
    logger.error(f"Error details: {type(e).__name__}: {str(e)}")
    clip_service = None
class ImageRequest(BaseModel):
    image_url: str

class TextRequest(BaseModel):
    text: str
@app.get("/")
async def root():
return {
"message": "CLIP Service API",
"version": "1.0.0",
"model": "clip-vit-large-patch14",
"endpoints": ["/encode/image", "/encode/text", "/health"],
"status": "ready" if clip_service else "error"
}
@app.post("/encode/image")
async def encode_image(request: ImageRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_image(request.image_url)
return {"embedding": embedding, "dimensions": len(embedding)}
@app.post("/encode/text")
async def encode_text(request: TextRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_text(request.text)
return {"embedding": embedding, "dimensions": len(embedding)}
@app.get("/health")
async def health_check():
if not clip_service:
return {
"status": "unhealthy",
"model": "clip-vit-large-patch14",
"error": "Service failed to initialize"
}
return {
"status": "healthy",
"model": "clip-vit-large-patch14",
"device": clip_service.device,
"service": "ready",
"cache_dir": cache_dir
}
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860)) # Hugging Face uses port 7860
uvicorn.run(app, host="0.0.0.0", port=port)
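# Example usage (a minimal sketch; assumes the service is running locally on the default
# port 7860 and that the placeholder image URL is replaced with a real, reachable image):
#
#     import requests
#
#     base = "http://localhost:7860"
#     img = requests.post(f"{base}/encode/image",
#                         json={"image_url": "https://example.com/cat.jpg"}).json()
#     txt = requests.post(f"{base}/encode/text",
#                         json={"text": "a photo of a cat"}).json()
#     print(img["dimensions"], txt["dimensions"])  # 768 for clip-vit-large-patch14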