from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel, ClapModel, ClapProcessor
import torch
from PIL import Image
import requests
import io
import logging
import librosa

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")

class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        logger.info("CLIP model loaded successfully")

        logger.info("Loading CLAP model for audio...")
        self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
        logger.info("CLAP model loaded successfully")

    def encode_image(self, image_url: str) -> list:
        try:
            # Enhanced headers for better compatibility with R2 CDN
            headers = {
                'User-Agent': 'CLIP-Service/1.0 (Image-Embedding-Service)',
                'Accept': 'image/*',
                'Cache-Control': 'no-cache'
            }
            logger.info(f"Fetching image from URL: {image_url}")
            response = requests.get(image_url, timeout=30, headers=headers)
            response.raise_for_status()

            # Log successful fetch with content info
            content_type = response.headers.get('content-type', 'unknown')
            content_length = len(response.content)
            logger.info(f"Successfully fetched image: {content_type}, {content_length} bytes")

            image = Image.open(io.BytesIO(response.content))
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Log image dimensions for debugging
            logger.info(f"Processing image: {image.size[0]}x{image.size[1]}")

            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
                # L2-normalize so downstream similarity is a plain dot product
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            embedding = image_features.numpy().flatten().tolist()
            logger.info(f"Generated embedding with {len(embedding)} dimensions")
            return embedding
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching image {image_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to image URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Image not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Image service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch image: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            inputs = self.processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
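
    # Because the image and text embeddings above are L2-normalized, cosine
    # similarity between them reduces to a dot product. A minimal sketch
    # (illustrative helper, not part of the service API):
    #
    #   def cosine_similarity(a: list, b: list) -> float:
    #       return sum(x * y for x, y in zip(a, b))
    #
    #   score = cosine_similarity(service.encode_text("a photo of a cat"),
    #                             service.encode_image("https://example.com/cat.jpg"))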

    def encode_audio(self, audio_url: str) -> list:
        try:
            # Enhanced headers for audio files with MIME whitelist
            headers = {
                'User-Agent': 'CLAP-Service/1.0 (Audio-Embedding-Service)',
                'Accept': 'audio/mpeg, audio/wav, audio/mp4, audio/ogg, audio/flac',
                'Cache-Control': 'no-cache'
            }
            logger.info(f"Fetching audio from URL: {audio_url}")

            # Longer timeout for large files; stream the response body
            response = requests.get(audio_url, timeout=60, headers=headers, stream=True)
            response.raise_for_status()

            # Check content type before processing
            content_type = response.headers.get('content-type', 'unknown')
            if not content_type.startswith('audio/'):
                raise ValueError(f"Invalid content type: {content_type}. Expected audio/*")

            # Check the declared file size before downloading (100MB limit)
            content_length = response.headers.get('content-length')
            if content_length and int(content_length) > 100 * 1024 * 1024:
                raise ValueError(f"Audio file too large: {content_length} bytes. Maximum is 100MB")

            # Stream content to BytesIO, enforcing the size limit even when the
            # Content-Length header is missing or wrong
            audio_data = io.BytesIO()
            total_size = 0
            max_size = 100 * 1024 * 1024  # 100MB
            for chunk in response.iter_content(chunk_size=8192):
                total_size += len(chunk)
                if total_size > max_size:
                    raise ValueError("Audio file too large during download")
                audio_data.write(chunk)
            audio_data.seek(0)
            logger.info(f"Successfully fetched audio: {content_type}, {total_size} bytes")

            # Load audio with duration limit (10 minutes = 600 seconds)
            MAX_DURATION = 600  # 10 minutes
            try:
                # Get the duration without decoding the full file; file-like
                # input works for formats soundfile can read
                duration = librosa.get_duration(path=audio_data)
                audio_data.seek(0)  # Reset stream
                if duration > MAX_DURATION:
                    raise ValueError(f"Audio duration ({duration:.1f}s) exceeds maximum allowed ({MAX_DURATION}s)")
                logger.info(f"Audio duration: {duration:.1f} seconds")

                # Load only the first 30 seconds for embedding (CLAP works well
                # with shorter clips); this reduces memory usage significantly
                duration_limit = min(30.0, duration)

                # Load audio with librosa (48kHz is CLAP's expected sample rate)
                waveform, sample_rate = librosa.load(
                    audio_data,
                    sr=48000,
                    mono=True,
                    duration=duration_limit,
                    offset=0.0
                )
                logger.info(f"Processing audio: {len(waveform)} samples at {sample_rate}Hz ({duration_limit:.1f}s)")
            except ValueError:
                # Let validation errors (e.g. the duration limit) pass through
                # unchanged instead of being re-wrapped as a load failure
                raise
            except Exception as e:
                logger.error(f"Error loading audio file: {str(e)}")
                raise ValueError(f"Failed to load audio file: {str(e)}")

            # Process audio through CLAP
            inputs = self.clap_processor(audios=waveform, return_tensors="pt", sampling_rate=48000)
            with torch.no_grad():
                audio_features = self.clap_model.get_audio_features(**inputs)
                # Normalize the features
                audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

            embedding = audio_features.numpy().flatten().tolist()
            logger.info(f"Generated audio embedding with {len(embedding)} dimensions")
            return embedding
        except ValueError as e:
            # Handle validation errors (file too large, wrong format, etc.)
            logger.error(f"Validation error for audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e))
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching audio {audio_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to audio URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Audio not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Audio service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch audio: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")

    def encode_text_for_audio(self, text: str) -> list:
        """Encode text for cross-modal audio search"""
        try:
            inputs = self.clap_processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.clap_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text for audio '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text for audio: {str(e)}")
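
# Cross-modal audio search sketch (illustrative, not part of the service).
# CLAP text and audio vectors share one embedding space that is distinct from
# CLIP's, so compare them only with each other and keep the two kinds of
# vectors in separate indexes:
#
#   query = clip_service.encode_text_for_audio("dog barking")
#   scores = [sum(q * a for q, a in zip(query, audio_vec))
#             for audio_vec in stored_audio_embeddings]  # hypothetical store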

# Initialize service
clip_service = CLIPService()


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str

# Route registration; the paths below follow the handler names and may need to
# be adjusted to match the rest of your deployment.
@app.post("/encode_image")
async def encode_image(request: ImageRequest):
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding}


@app.post("/encode_text")
async def encode_text(request: TextRequest):
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding}


@app.post("/encode_audio")
async def encode_audio(request: AudioRequest):
    """Encode audio file to CLAP embedding vector"""
    embedding = clip_service.encode_audio(request.audio_url)
    return {"embedding": embedding}


@app.post("/encode_text_for_audio")
async def encode_text_for_audio(request: TextRequest):
    """Encode text for audio similarity search"""
    embedding = clip_service.encode_text_for_audio(request.text)
    return {"embedding": embedding}


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "clip_model": "openai/clip-vit-large-patch14",
        "clap_model": "laion/clap-htsat-unfused"
    }


@app.post("/validate_image_url")
async def validate_image_url(request: ImageRequest):
    """Validate that an image URL is accessible without processing it"""
    try:
        headers = {
            'User-Agent': 'CLIP-Service/1.0 (Image-Validation)',
            'Accept': 'image/*'
        }
        response = requests.head(request.image_url, timeout=10, headers=headers)
        response.raise_for_status()
        content_type = response.headers.get('content-type', 'unknown')
        content_length = response.headers.get('content-length', 'unknown')
        return {
            "accessible": True,
            "content_type": content_type,
            "content_length": content_length,
            "url": request.image_url
        }
    except Exception as e:
        return {
            "accessible": False,
            "error": str(e),
            "url": request.image_url
        }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
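
# Example client calls (illustrative; the paths follow the route registrations
# above, and the host/port match the uvicorn settings used here):
#
#   curl -X POST http://localhost:8000/encode_text \
#        -H "Content-Type: application/json" \
#        -d '{"text": "a photo of a cat"}'
#
#   curl -X POST http://localhost:8000/encode_audio \
#        -H "Content-Type: application/json" \
#        -d '{"audio_url": "https://example.com/clip.mp3"}'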