rmoxon committed
Commit c819b55 · verified · 1 Parent(s): 58c2e09

Upload 4 files

Files changed (4)
  1. app-simple.py +239 -0
  2. app.py +19 -42
  3. requirements-simple.txt +9 -7
  4. requirements.txt +9 -10
app-simple.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import tempfile
+ from pathlib import Path
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import CLIPProcessor, CLIPModel
+ import torch
+ from PIL import Image
+ import requests
+ import numpy as np
+ import io
+ import logging
+
+ # Set up cache directories
+ cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/code/cache')
+ os.makedirs(cache_dir, exist_ok=True)
+ os.environ['TRANSFORMERS_CACHE'] = cache_dir
+ os.environ['HF_HOME'] = cache_dir
+ os.environ['TORCH_HOME'] = cache_dir
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="CLIP Service", version="1.0.0")
+
+ class CLIPService:
+     def __init__(self):
+         logger.info("Loading CLIP model...")
+         try:
+             # Use CPU for Hugging Face free tier
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             logger.info(f"Using device: {self.device}")
+
+             # Load model with explicit cache directory
+             self.model = CLIPModel.from_pretrained(
+                 "openai/clip-vit-large-patch14",
+                 cache_dir=cache_dir,
+                 local_files_only=False
+             ).to(self.device)
+
+             self.processor = CLIPProcessor.from_pretrained(
+                 "openai/clip-vit-large-patch14",
+                 cache_dir=cache_dir,
+                 local_files_only=False
+             )
+
+             logger.info(f"CLIP model loaded successfully on {self.device}")
+
+         except Exception as e:
+             logger.error(f"Failed to load CLIP model: {str(e)}")
+             raise RuntimeError(f"Model loading failed: {str(e)}")
+
+     def is_supported_format(self, image_url: str) -> bool:
+         """Check if image format is supported by PIL/CLIP"""
+         unsupported_extensions = ['.avif', '.heic', '.heif']
+         url_lower = image_url.lower()
+         return not any(url_lower.endswith(ext) for ext in unsupported_extensions)
+
+     def detect_image_format(self, content: bytes) -> str:
+         """Detect actual image format from content"""
+         try:
+             # Check for AVIF signature
+             if content.startswith(b'\x00\x00\x00') and b'ftypavif' in content[:32]:
+                 return 'AVIF'
+             # Check for HEIC signature
+             elif content.startswith(b'\x00\x00\x00') and b'ftyp' in content[:32] and (b'heic' in content[:32] or b'heix' in content[:32]):
+                 return 'HEIC'
+             # Check for WebP
+             elif content.startswith(b'RIFF') and b'WEBP' in content[:12]:
+                 return 'WebP'
+             # Check for PNG
+             elif content.startswith(b'\x89PNG\r\n\x1a\n'):
+                 return 'PNG'
+             # Check for JPEG
+             elif content.startswith(b'\xff\xd8\xff'):
+                 return 'JPEG'
+             # Check for GIF
+             elif content.startswith((b'GIF87a', b'GIF89a')):
+                 return 'GIF'
+             else:
+                 return 'Unknown'
+         except Exception:
+             return 'Unknown'
+
+     def encode_image(self, image_url: str) -> list:
+         try:
+             logger.info(f"Processing image: {image_url}")
+
+             # Quick URL-based format check first
+             if not self.is_supported_format(image_url):
+                 logger.warning(f"Unsupported format detected from URL: {image_url}")
+                 raise HTTPException(status_code=422, detail="Unsupported image format (AVIF/HEIC not supported)")
+
+             response = requests.get(image_url, timeout=30, headers={'User-Agent': 'CLIP-Service/1.0'})
+             response.raise_for_status()
+
+             # Detect actual format from content
+             image_format = self.detect_image_format(response.content)
+             logger.info(f"Detected image format: {image_format}")
+
+             if image_format in ['AVIF', 'HEIC']:
+                 logger.warning(f"Unsupported format detected: {image_format} for {image_url}")
+                 raise HTTPException(status_code=422, detail=f"Unsupported image format: {image_format}")
+
+             try:
+                 image = Image.open(io.BytesIO(response.content))
+             except Exception as e:
+                 logger.error(f"PIL cannot open image {image_url}: {str(e)}")
+                 if "cannot identify image file" in str(e).lower():
+                     raise HTTPException(status_code=422, detail="Unsupported or corrupted image format")
+                 raise
+
+             if image.mode != 'RGB':
+                 logger.info(f"Converting image from {image.mode} to RGB")
+                 image = image.convert('RGB')
+
+             # Resize image if too large to avoid memory issues
+             max_size = 224  # CLIP's expected input size
+             if max(image.size) > max_size:
+                 image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
+
+             # Try multiple processor configurations
+             try:
+                 # Method 1: Standard CLIP processing
+                 inputs = self.processor(
+                     images=image,
+                     return_tensors="pt",
+                     do_rescale=True,
+                     do_normalize=True
+                 )
+             except Exception as e1:
+                 logger.warning(f"Method 1 failed: {e1}, trying method 2...")
+                 try:
+                     # Method 2: With padding
+                     inputs = self.processor(
+                         images=image,
+                         return_tensors="pt",
+                         padding=True,
+                         do_rescale=True,
+                         do_normalize=True
+                     )
+                 except Exception as e2:
+                     logger.warning(f"Method 2 failed: {e2}, trying method 3...")
+                     # Method 3: Manual preprocessing
+                     inputs = self.processor(
+                         images=[image],
+                         return_tensors="pt"
+                     )
+
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 image_features = self.model.get_image_features(**inputs)
+                 image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+             return image_features.cpu().numpy().flatten().tolist()
+
+         except Exception as e:
+             logger.error(f"Error encoding image {image_url}: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
+
+     def encode_text(self, text: str) -> list:
+         try:
+             logger.info(f"Processing text: {text[:50]}...")
+             inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
+
+             with torch.no_grad():
+                 text_features = self.model.get_text_features(**inputs)
+                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+             return text_features.cpu().numpy().flatten().tolist()
+         except Exception as e:
+             logger.error(f"Error encoding text '{text[:50]}...': {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Failed to encode text: {str(e)}")
+
+ # Initialize service with error handling
+ logger.info("Initializing CLIP service...")
+ try:
+     clip_service = CLIPService()
+     logger.info("CLIP service initialized successfully!")
+ except Exception as e:
+     logger.error(f"Failed to initialize CLIP service: {str(e)}")
+     logger.error(f"Error details: {type(e).__name__}: {str(e)}")
+     clip_service = None
+
+ class ImageRequest(BaseModel):
+     image_url: str
+
+ class TextRequest(BaseModel):
+     text: str
+
+ @app.get("/")
+ async def root():
+     return {
+         "message": "CLIP Service API",
+         "version": "1.0.0",
+         "model": "clip-vit-large-patch14",
+         "endpoints": ["/encode/image", "/encode/text", "/health"],
+         "status": "ready" if clip_service else "error"
+     }
+
+ @app.post("/encode/image")
+ async def encode_image(request: ImageRequest):
+     if not clip_service:
+         raise HTTPException(status_code=503, detail="CLIP service not available")
+
+     embedding = clip_service.encode_image(request.image_url)
+     return {"embedding": embedding, "dimensions": len(embedding)}
+
+ @app.post("/encode/text")
+ async def encode_text(request: TextRequest):
+     if not clip_service:
+         raise HTTPException(status_code=503, detail="CLIP service not available")
+
+     embedding = clip_service.encode_text(request.text)
+     return {"embedding": embedding, "dimensions": len(embedding)}
+
+ @app.get("/health")
+ async def health_check():
+     if not clip_service:
+         return {
+             "status": "unhealthy",
+             "model": "clip-vit-large-patch14",
+             "error": "Service failed to initialize"
+         }
+
+     return {
+         "status": "healthy",
+         "model": "clip-vit-large-patch14",
+         "device": clip_service.device,
+         "service": "ready",
+         "cache_dir": cache_dir
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.environ.get("PORT", 7860))  # Hugging Face uses port 7860
+     uvicorn.run(app, host="0.0.0.0", port=port)
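
The new app-simple.py exposes three endpoints: /encode/image, /encode/text, and /health. A minimal client sketch against the running service, assuming it is reachable on the default port 7860 (the base URL and the image URL below are placeholders, not part of the commit):

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment; the app binds to port 7860 by default

# Health check: reports device and cache directory once the model has loaded
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Image embedding from a URL (placeholder image URL)
r = requests.post(
    f"{BASE_URL}/encode/image",
    json={"image_url": "https://example.com/sample.jpg"},
    timeout=60,
)
print(r.json()["dimensions"])  # 768 for clip-vit-large-patch14

# Text embedding in the same space, usable for similarity search against image embeddings
r = requests.post(f"{BASE_URL}/encode/text", json={"text": "a photo of a dog"}, timeout=60)
print(r.json()["dimensions"])
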
app.py CHANGED
@@ -9,13 +9,8 @@ try:
      CLAP_AVAILABLE = True
      CLAP_METHOD = "transformers"
  except ImportError as e1:
-     try:
-         import laion_clap
-         CLAP_AVAILABLE = True
-         CLAP_METHOD = "laion"
-     except ImportError as e2:
-         CLAP_AVAILABLE = False
-         CLAP_METHOD = None
+     CLAP_AVAILABLE = False
+     CLAP_METHOD = None
  import torch
  from PIL import Image
  import requests
@@ -77,33 +72,31 @@ class CLIPService:
      def _load_clap_model(self):
          """Load CLAP model on demand"""
          if not CLAP_AVAILABLE:
-             raise RuntimeError("CLAP model not available")
+             raise RuntimeError("CLAP model not available - transformers version may not support CLAP")

          if self.clap_model is None:
              logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
              try:
                  if CLAP_METHOD == "transformers":
+                     logger.info("Loading CLAP model from HuggingFace...")
                      self.clap_model = ClapModel.from_pretrained(
                          "laion/clap-htsat-unfused",
                          cache_dir=cache_dir,
                          local_files_only=False
                      ).to(self.device)

+                     logger.info("Loading CLAP processor...")
                      self.clap_processor = ClapProcessor.from_pretrained(
                          "laion/clap-htsat-unfused",
                          cache_dir=cache_dir,
                          local_files_only=False
                      )

-                 elif CLAP_METHOD == "laion":
-                     # Use the official LAION CLAP library
-                     self.clap_model = laion_clap.CLAP_Module(enable_fusion=False)
-                     self.clap_model.load_ckpt()  # Load the default checkpoint
-
                  logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")

              except Exception as e:
                  logger.error(f"Failed to load CLAP model: {str(e)}")
+                 logger.error(f"Error type: {type(e).__name__}")
                  raise RuntimeError(f"CLAP model loading failed: {str(e)}")

      def is_supported_format(self, image_url: str) -> bool:
@@ -255,36 +248,20 @@ class CLIPService:
          if len(audio_array) > max_length:
              audio_array = audio_array[:max_length]

-         # Process with CLAP based on method
-         if CLAP_METHOD == "transformers":
-             inputs = self.clap_processor(
-                 audios=audio_array,
-                 sampling_rate=48000,
-                 return_tensors="pt"
-             )
-
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 audio_features = self.clap_model.get_audio_features(**inputs)
-                 audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
-
-             return audio_features.cpu().numpy().flatten().tolist()
-
-         elif CLAP_METHOD == "laion":
-             # Use LAION CLAP library
-             with torch.no_grad():
-                 audio_features = self.clap_model.get_audio_embedding_from_data(
-                     x=audio_array,
-                     use_tensor=True
-                 )
-                 # Normalize embedding
-                 audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
-
-             return audio_features.cpu().numpy().flatten().tolist()
+         # Process with CLAP using transformers method
+         inputs = self.clap_processor(
+             audios=audio_array,
+             sampling_rate=48000,
+             return_tensors="pt"
+         )
+
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             audio_features = self.clap_model.get_audio_features(**inputs)
+             audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

-         else:
-             raise RuntimeError(f"Unknown CLAP method: {CLAP_METHOD}")
+         return audio_features.cpu().numpy().flatten().tolist()

      finally:
          # Clean up temp file
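
With the laion-clap fallback removed, the audio path in app.py relies only on the transformers Clap classes. A standalone sketch of that same path, under the assumptions the diff makes (48 kHz audio, L2-normalized output); librosa is used for loading and resampling, and "sample.wav" is a placeholder path:

import librosa
import torch
from transformers import ClapModel, ClapProcessor

# Same checkpoint app.py loads on demand; downloads to the cache on first use
model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")

# CLAP expects 48 kHz audio; librosa resamples while loading ("sample.wav" is a placeholder)
audio_array, _ = librosa.load("sample.wav", sr=48000, mono=True)

inputs = processor(audios=audio_array, sampling_rate=48000, return_tensors="pt")

with torch.no_grad():
    audio_features = model.get_audio_features(**inputs)
    # L2-normalize, matching what the service returns
    audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)

embedding = audio_features.squeeze(0).tolist()
print(len(embedding))  # projection size of the checkpoint (512 for clap-htsat-unfused)
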
requirements-simple.txt CHANGED
@@ -1,7 +1,9 @@
- torch>=2.0.0
- transformers>=4.30.0
- Pillow>=9.0.0
- requests>=2.28.0
- fastapi>=0.104.0
- uvicorn[standard]>=0.22.0
- python-multipart>=0.0.6
+ torch>=2.1.0
+ transformers==4.30.0
+ Pillow==9.5.0
+ requests==2.31.0
+ fastapi==0.104.1
+ uvicorn==0.22.0
+ python-multipart==0.0.6
+ pydantic==2.5.0
+ numpy<2.0.0
requirements.txt CHANGED
@@ -1,13 +1,12 @@
- torch==2.0.1
- transformers>=4.35.0
- Pillow==9.5.0
- requests==2.31.0
- fastapi==0.104.1
- uvicorn==0.22.0
- python-multipart==0.0.6
- pydantic==2.5.0
+ torch>=2.1.0
+ transformers>=4.40.0
+ Pillow>=9.5.0
+ requests>=2.31.0
+ fastapi>=0.104.1
+ uvicorn>=0.22.0
+ python-multipart>=0.0.6
+ pydantic>=2.5.0
  numpy<2.0.0
  librosa>=0.10.0
  soundfile>=0.12.1
- datasets>=2.14.0
- laion-clap
+ datasets>=2.14.0