# strandtest/main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel, ClapModel, ClapProcessor
import torch
from PIL import Image
import requests
import numpy as np
import io
import logging
import librosa
import soundfile as sf

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="CLIP Service", version="1.0.0")

class CLIPService:
    def __init__(self):
        logger.info("Loading CLIP model...")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        logger.info("CLIP model loaded successfully")

        logger.info("Loading CLAP model for audio...")
        self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
        logger.info("CLAP model loaded successfully")

    def encode_image(self, image_url: str) -> list:
        try:
            # Enhanced headers for better compatibility with R2 CDN
            headers = {
                'User-Agent': 'CLIP-Service/1.0 (Image-Embedding-Service)',
                'Accept': 'image/*',
                'Cache-Control': 'no-cache'
            }

            logger.info(f"Fetching image from URL: {image_url}")
            response = requests.get(image_url, timeout=30, headers=headers)
            response.raise_for_status()

            # Log successful fetch with content info
            content_type = response.headers.get('content-type', 'unknown')
            content_length = len(response.content)
            logger.info(f"Successfully fetched image: {content_type}, {content_length} bytes")

            image = Image.open(io.BytesIO(response.content))
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            # Log image dimensions for debugging
            logger.info(f"Processing image: {image.size[0]}x{image.size[1]}")

            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            embedding = image_features.numpy().flatten().tolist()
            logger.info(f"Generated embedding with {len(embedding)} dimensions")
            return embedding

        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching image {image_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to image URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Image not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Image service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch image: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding image {image_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")

    def encode_text(self, text: str) -> list:
        try:
            inputs = self.processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")

    def encode_audio(self, audio_url: str) -> list:
        try:
            # Enhanced headers for audio files with MIME whitelist
            headers = {
                'User-Agent': 'CLAP-Service/1.0 (Audio-Embedding-Service)',
                'Accept': 'audio/mpeg, audio/wav, audio/mp4, audio/ogg, audio/flac',
                'Cache-Control': 'no-cache'
            }

            logger.info(f"Fetching audio from URL: {audio_url}")

            # Use a longer timeout for large files and stream the response
            response = requests.get(audio_url, timeout=60, headers=headers, stream=True)
            response.raise_for_status()

            # Check content type before processing
            content_type = response.headers.get('content-type', 'unknown')
            if not content_type.startswith('audio/'):
                raise ValueError(f"Invalid content type: {content_type}. Expected audio/*")

            # Check file size before downloading (100MB limit)
            content_length = response.headers.get('content-length')
            if content_length and int(content_length) > 100 * 1024 * 1024:
                raise ValueError(f"Audio file too large: {content_length} bytes. Maximum is 100MB")

            # Stream content to BytesIO with size limit
            audio_data = io.BytesIO()
            total_size = 0
            max_size = 100 * 1024 * 1024  # 100MB
            for chunk in response.iter_content(chunk_size=8192):
                total_size += len(chunk)
                if total_size > max_size:
                    raise ValueError("Audio file too large during download")
                audio_data.write(chunk)
            audio_data.seek(0)

            logger.info(f"Successfully fetched audio: {content_type}, {total_size} bytes")

            # Load audio with duration limit (10 minutes = 600 seconds)
            MAX_DURATION = 600  # 10 minutes
            try:
                # First, get the duration without loading the full audio
                duration = librosa.get_duration(path=audio_data)
                audio_data.seek(0)  # Reset stream

                if duration > MAX_DURATION:
                    raise ValueError(f"Audio duration ({duration:.1f}s) exceeds maximum allowed ({MAX_DURATION}s)")

                logger.info(f"Audio duration: {duration:.1f} seconds")

                # Load only the first 30 seconds for embedding (CLAP works well with shorter clips)
                # This reduces memory usage significantly
                duration_limit = min(30.0, duration)

                # Load audio with librosa (48kHz is CLAP's expected sample rate)
                waveform, sample_rate = librosa.load(
                    audio_data,
                    sr=48000,
                    mono=True,
                    duration=duration_limit,
                    offset=0.0
                )
                logger.info(f"Processing audio: {len(waveform)} samples at {sample_rate}Hz ({duration_limit:.1f}s)")
            except Exception as e:
                logger.error(f"Error loading audio file: {str(e)}")
                raise ValueError(f"Failed to load audio file: {str(e)}")

            # Process audio through CLAP
            inputs = self.clap_processor(audios=waveform, return_tensors="pt", sampling_rate=48000)
            with torch.no_grad():
                audio_features = self.clap_model.get_audio_features(**inputs)
                # Normalize the features
                audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

            embedding = audio_features.numpy().flatten().tolist()
            logger.info(f"Generated audio embedding with {len(embedding)} dimensions")
            return embedding

        except ValueError as e:
            # Handle validation errors (file too large, wrong format, etc.)
            logger.error(f"Validation error for audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e))
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching audio {audio_url}: {str(e)}")
            if hasattr(e, 'response') and e.response is not None:
                status_code = e.response.status_code
                if status_code == 403:
                    raise HTTPException(status_code=403, detail="Access denied to audio URL")
                elif status_code == 404:
                    raise HTTPException(status_code=404, detail="Audio not found at URL")
                elif status_code >= 500:
                    raise HTTPException(status_code=502, detail="Audio service temporarily unavailable")
            raise HTTPException(status_code=500, detail=f"Failed to fetch audio: {str(e)}")
        except Exception as e:
            logger.error(f"Error encoding audio {audio_url}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode audio: {str(e)}")

    def encode_text_for_audio(self, text: str) -> list:
        """Encode text for cross-modal audio search"""
        try:
            inputs = self.clap_processor(text=[text], return_tensors="pt", padding=True)
            with torch.no_grad():
                text_features = self.clap_model.get_text_features(**inputs)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.numpy().flatten().tolist()
        except Exception as e:
            logger.error(f"Error encoding text for audio '{text}': {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to encode text for audio: {str(e)}")

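# Note: all embeddings returned by this service are L2-normalized, so cosine
# similarity between two embeddings reduces to a plain dot product. A minimal
# sketch of how a caller might rank stored audio embeddings against a CLAP text
# embedding (the helper below is illustrative only, not part of this service):
#
#     import numpy as np
#
#     def rank_by_similarity(query_embedding, candidate_embeddings):
#         """Return candidate indices sorted by cosine similarity, best first."""
#         query = np.asarray(query_embedding)
#         candidates = np.asarray(candidate_embeddings)
#         scores = candidates @ query  # dot product == cosine for unit vectors
#         return np.argsort(-scores).tolist()
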
# Initialize service
clip_service = CLIPService()


class ImageRequest(BaseModel):
    image_url: str


class TextRequest(BaseModel):
    text: str


class AudioRequest(BaseModel):
    audio_url: str

@app.post("/encode/image")
async def encode_image(request: ImageRequest):
    """Encode an image URL to a CLIP embedding vector."""
    embedding = clip_service.encode_image(request.image_url)
    return {"embedding": embedding}


@app.post("/encode/text")
async def encode_text(request: TextRequest):
    """Encode text to a CLIP embedding vector."""
    embedding = clip_service.encode_text(request.text)
    return {"embedding": embedding}


@app.post("/encode/audio")
async def encode_audio(request: AudioRequest):
    """Encode an audio file to a CLAP embedding vector."""
    embedding = clip_service.encode_audio(request.audio_url)
    return {"embedding": embedding}


@app.post("/encode/text-audio")
async def encode_text_for_audio(request: TextRequest):
    """Encode text for audio similarity search."""
    embedding = clip_service.encode_text_for_audio(request.text)
    return {"embedding": embedding}

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "clip-vit-large-patch14"}


@app.post("/validate/image")
async def validate_image_url(request: ImageRequest):
    """Validate that an image URL is accessible without processing it"""
    try:
        headers = {
            'User-Agent': 'CLIP-Service/1.0 (Image-Validation)',
            'Accept': 'image/*'
        }
        response = requests.head(request.image_url, timeout=10, headers=headers)
        response.raise_for_status()

        content_type = response.headers.get('content-type', 'unknown')
        content_length = response.headers.get('content-length', 'unknown')

        return {
            "accessible": True,
            "content_type": content_type,
            "content_length": content_length,
            "url": request.image_url
        }
    except Exception as e:
        return {
            "accessible": False,
            "error": str(e),
            "url": request.image_url
        }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
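
# Example usage (a sketch, assuming the service is running locally on port 8000;
# the image URL below is a placeholder, not something this repository provides):
#
#     import requests
#
#     base = "http://localhost:8000"
#     text_emb = requests.post(f"{base}/encode/text",
#                              json={"text": "a photo of a dog"}).json()["embedding"]
#     image_emb = requests.post(f"{base}/encode/image",
#                               json={"image_url": "https://example.com/dog.jpg"}).json()["embedding"]
#     print(len(text_emb), len(image_emb))  # CLIP ViT-L/14 embeddings are 768-dimensional
#
# Each endpoint returns a JSON body of the form {"embedding": [...]}.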