# strandtest / app.py
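"""FastAPI service exposing CLIP image/text embeddings and, when the installed
transformers version supports it, CLAP audio embeddings. Models are downloaded
from the Hugging Face Hub into a local cache directory; CLAP is loaded lazily on
the first /encode/audio request. Intended to run on a Hugging Face Space.
"""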
import os
import io
import tempfile
import logging

# Set up cache directories before importing transformers, which resolves its
# cache paths from the environment at import time.
cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

import torch
import requests
import numpy as np
import librosa
import soundfile as sf
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel

# CLAP support is optional: older transformers releases do not ship ClapModel.
try:
    from transformers import ClapModel, ClapProcessor
    CLAP_AVAILABLE = True
    CLAP_METHOD = "transformers"
except ImportError:
    CLAP_AVAILABLE = False
    CLAP_METHOD = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")

# Log CLAP availability after the logger is initialized
logger.info(f"CLAP availability: {CLAP_AVAILABLE}, method: {CLAP_METHOD}")
class CLIPService:
def __init__(self):
logger.info("Loading CLIP model...")
self.clap_model = None
self.clap_processor = None
try:
            # Prefer GPU when available; otherwise fall back to CPU (the Hugging Face free tier is CPU-only)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Load CLIP model with explicit cache directory
logger.info("Loading CLIP model from HuggingFace...")
self.clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLIP processor...")
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=cache_dir,
local_files_only=False
)
logger.info(f"CLIP model loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load CLIP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLIP model loading failed: {str(e)}")
def _load_clap_model(self):
"""Load CLAP model on demand"""
if not CLAP_AVAILABLE:
raise RuntimeError("CLAP model not available - transformers version may not support CLAP")
if self.clap_model is None:
logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
try:
if CLAP_METHOD == "transformers":
logger.info("Loading CLAP model from HuggingFace...")
self.clap_model = ClapModel.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False
).to(self.device)
logger.info("Loading CLAP processor...")
self.clap_processor = ClapProcessor.from_pretrained(
"laion/clap-htsat-unfused",
cache_dir=cache_dir,
local_files_only=False
)
logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
except Exception as e:
logger.error(f"Failed to load CLAP model: {str(e)}")
logger.error(f"Error type: {type(e).__name__}")
raise RuntimeError(f"CLAP model loading failed: {str(e)}")
def is_supported_format(self, image_url: str) -> bool:
"""Check if image format is supported by PIL/CLIP"""
unsupported_extensions = ['.avif', '.heic', '.heif']
url_lower = image_url.lower()
return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
def detect_image_format(self, content: bytes) -> str:
"""Detect actual image format from content"""
try:
# Check for AVIF signature
if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
return 'AVIF'
# Check for HEIC signature
elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
return 'HEIC'
# Check for WebP
elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
return 'WebP'
# Check for PNG
elif content.startswith(b'\x89PNG\r\n\x1a\n'):
return 'PNG'
# Check for JPEG
elif content.startswith(b'\xff\xd8\xff'):
return 'JPEG'
# Check for GIF
elif content.startswith((b'GIF87a', b'GIF89a')):
return 'GIF'
else:
return 'Unknown'
        except Exception:
            return 'Unknown'
def encode_image(self, image_url: str) -> list:
try:
logger.info(f"Processing image: {image_url}")
# Quick URL-based format check first
if not self.is_supported_format(image_url):
logger.warning(f"Unsupported format detected from URL: {image_url}")
raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
response.raise_for_status()
# Detect actual format from content
image_format = self.detect_image_format(response.content)
logger.info(f"Detected image format: {image_format}")
if image_format in ['AVIF', 'HEIC']:
logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
try:
image = Image.open(io.BytesIO(response.content))
except Exception as e:
logger.error(f"PIL cannot open image {image_url}: {str(e)}")
if "cannot identify image file" in str(e).lower():
raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
raise
if image.mode != 'RGB':
logger.info(f"Converting image from {image.mode} to RGB")
image = image.convert('RGB')
# Resize image if too large to avoid memory issues
max_size = 224 # CLIP's expected input size
if max(image.size) > max_size:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
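            # (The CLIP processor still resizes/crops to its expected 224x224 input;
            # the thumbnail above only caps memory use for very large downloads.)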
# Try multiple processor configurations
try:
# Method 1: Standard CLIP processing
inputs = self.clip_processor(
images=image,
return_tensors="pt",
do_rescale=True,
do_normalize=True
)
except Exception as e1:
logger.warning(f"Method 1 failed: {e1}, trying method 2...")
try:
# Method 2: With padding
inputs = self.clip_processor(
images=image,
return_tensors="pt",
padding=True,
do_rescale=True,
do_normalize=True
)
except Exception as e2:
logger.warning(f"Method 2 failed: {e2}, trying method 3...")
# Method 3: Manual preprocessing
inputs = self.clip_processor(
images=[image],
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
image_features = self.clip_model.get_image_features(**inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().flatten().tolist()
        except HTTPException:
            # Preserve intentional 4xx responses (e.g. unsupported formats) instead
            # of converting them into a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
def encode_text(self, text: str) -> list:
try:
logger.info(f"Processing text: {text[:50]}...")
inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True).to(self.device)
with torch.no_grad():
text_features = self.clip_model.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().flatten().tolist()
except Exception as e:
logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
def encode_audio(self, audio_url: str) -> list:
try:
logger.info(f"Processing audio: {audio_url}")
# Load CLAP model on demand
self._load_clap_model()
# Download audio file
response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
response.raise_for_status()
            # Save to a temporary file (the .wav suffix is nominal; librosa detects
            # the actual format from the file contents)
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
tmp_file.write(response.content)
tmp_path = tmp_file.name
try:
# Load audio with librosa
# CLAP expects 48kHz sampling rate
audio_array, sample_rate = librosa.load(tmp_path, sr=48000, mono=True)
# Ensure audio is not too long (max 30 seconds for CLAP)
max_length = 30 * 48000 # 30 seconds at 48kHz
if len(audio_array) > max_length:
audio_array = audio_array[:max_length]
# Process with CLAP using transformers method
inputs = self.clap_processor(
audios=audio_array,
sampling_rate=48000,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
audio_features = self.clap_model.get_audio_features(**inputs)
audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
return audio_features.cpu().numpy().flatten().tolist()
finally:
# Clean up temp file
if os.path.exists(tmp_path):
os.unlink(tmp_path)
except Exception as e:
logger.error(f"Error encoding audio {audio_url}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")
# Initialize service with error handling
logger.info("Initializing CLIP service...")
try:
clip_service = CLIPService()
logger.info("CLIP service initialized successfully!")
except Exception as e:
logger.error(f"Failed to initialize CLIP service: {str(e)}")
logger.error(f"Error details: {type(e).__name__}: {str(e)}")
# For now, we'll let the app start but service calls will fail gracefully
clip_service = None
class ImageRequest(BaseModel):
image_url: str
class TextRequest(BaseModel):
text: str
class AudioRequest(BaseModel):
audio_url: str
@app.get("/")
async def root():
return {
"message": "CLIP Service API",
"version": "1.0.0",
"model": "clip-vit-large-patch14",
"endpoints": ["/encode/image", "/encode/text", "/encode/audio", "/health"],
"status": "ready" if clip_service else "error"
}
@app.post("/encode/image")
async def encode_image(request: ImageRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_image(request.image_url)
return {"embedding": embedding, "dimensions": len(embedding)}
@app.post("/encode/text")
async def encode_text(request: TextRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLIP service not available")
embedding = clip_service.encode_text(request.text)
return {"embedding": embedding, "dimensions": len(embedding)}
@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
if not clip_service:
raise HTTPException(status_code=503, detail="CLAP service not available")
if not CLAP_AVAILABLE:
raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
embedding = clip_service.encode_audio(request.audio_url)
return {"embedding": embedding, "dimensions": len(embedding)}
@app.get("/health")
async def health_check():
if not clip_service:
return {
"status": "unhealthy",
"model": "clip-vit-large-patch14",
"error": "Service failed to initialize"
}
return {
"status": "healthy",
"models": {
"clip": "clip-vit-large-patch14",
"clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available"
},
"device": clip_service.device,
"service": "ready",
"cache_dir": cache_dir
}
if __name__ == "__main__":
import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces expose port 7860 by default
uvicorn.run(app, host="0.0.0.0", port=port)