rmoxon committed
Commit eba5056 · verified · 1 Parent(s): 9ee1beb

Upload 2 files

Files changed (2):
  1. app.py +90 -30
  2. requirements.txt +3 -2
app.py CHANGED
@@ -3,7 +3,19 @@ import tempfile
 from pathlib import Path
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import CLIPProcessor, CLIPModel, ClapModel, ClapProcessor
+from transformers import CLIPProcessor, CLIPModel
+try:
+    from transformers import ClapModel, ClapProcessor
+    CLAP_AVAILABLE = True
+    CLAP_METHOD = "transformers"
+except ImportError:
+    try:
+        import laion_clap
+        CLAP_AVAILABLE = True
+        CLAP_METHOD = "laion"
+    except ImportError:
+        CLAP_AVAILABLE = False
+        CLAP_METHOD = None
 import torch
 from PIL import Image
 import requests
@@ -28,7 +40,7 @@ app = FastAPI(title="CLIP Service", version="1.0.0")
 
 class CLIPService:
     def __init__(self):
-        logger.info("Loading CLIP and CLAP models...")
+        logger.info("Loading CLIP model...")
         try:
             # Use CPU for Hugging Face free tier
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -47,25 +59,48 @@ class CLIPService:
                 local_files_only=False
             )
 
-            # Load CLAP model for audio processing
-            self.clap_model = ClapModel.from_pretrained(
-                "laion/clap-htsat-unfused",
-                cache_dir=cache_dir,
-                local_files_only=False
-            ).to(self.device)
+            # Initialize CLAP model placeholders (loaded on demand)
+            self.clap_model = None
+            self.clap_processor = None
 
-            self.clap_processor = ClapProcessor.from_pretrained(
-                "laion/clap-htsat-unfused",
-                cache_dir=cache_dir,
-                local_files_only=False
-            )
-
-            logger.info(f"CLIP and CLAP models loaded successfully on {self.device}")
+            logger.info(f"CLIP model loaded successfully on {self.device}")
 
         except Exception as e:
-            logger.error(f"Failed to load models: {str(e)}")
+            logger.error(f"Failed to load CLIP model: {str(e)}")
             raise RuntimeError(f"Model loading failed: {str(e)}")
 
+    def _load_clap_model(self):
+        """Load CLAP model on demand"""
+        if not CLAP_AVAILABLE:
+            raise RuntimeError("CLAP model not available")
+
+        if self.clap_model is None:
+            logger.info(f"Loading CLAP model on demand using {CLAP_METHOD} method...")
+            try:
+                if CLAP_METHOD == "transformers":
+                    self.clap_model = ClapModel.from_pretrained(
+                        "laion/clap-htsat-unfused",
+                        cache_dir=cache_dir,
+                        local_files_only=False
+                    ).to(self.device)
+
+                    self.clap_processor = ClapProcessor.from_pretrained(
+                        "laion/clap-htsat-unfused",
+                        cache_dir=cache_dir,
+                        local_files_only=False
+                    )
+
+                elif CLAP_METHOD == "laion":
+                    # Use the official LAION CLAP library
+                    self.clap_model = laion_clap.CLAP_Module(enable_fusion=False)
+                    self.clap_model.load_ckpt()  # Load the default checkpoint
+
+                logger.info(f"CLAP model loaded successfully on {self.device} using {CLAP_METHOD}")
+
+            except Exception as e:
+                logger.error(f"Failed to load CLAP model: {str(e)}")
+                raise RuntimeError(f"CLAP model loading failed: {str(e)}")
+
     def is_supported_format(self, image_url: str) -> bool:
         """Check if image format is supported by PIL/CLIP"""
         unsupported_extensions = ['.avif', '.heic', '.heif']
@@ -193,6 +228,9 @@ class CLIPService:
         try:
             logger.info(f"Processing audio: {audio_url}")
 
+            # Load CLAP model on demand
+            self._load_clap_model()
+
             # Download audio file
             response = requests.get(audio_url, timeout=60, headers={'User-Agent': 'CLAP-Service/1.0'})
             response.raise_for_status()
@@ -212,20 +250,36 @@ class CLIPService:
             if len(audio_array) > max_length:
                 audio_array = audio_array[:max_length]
 
-            # Process with CLAP
-            inputs = self.clap_processor(
-                audios=audio_array,
-                sampling_rate=48000,
-                return_tensors="pt"
-            )
-
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                audio_features = self.clap_model.get_audio_features(**inputs)
-                audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
+            # Process with CLAP based on method
+            if CLAP_METHOD == "transformers":
+                inputs = self.clap_processor(
+                    audios=audio_array,
+                    sampling_rate=48000,
+                    return_tensors="pt"
+                )
+
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                with torch.no_grad():
+                    audio_features = self.clap_model.get_audio_features(**inputs)
+                    audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
+
+                return audio_features.cpu().numpy().flatten().tolist()
+
+            elif CLAP_METHOD == "laion":
+                # Use LAION CLAP library
+                with torch.no_grad():
+                    audio_features = self.clap_model.get_audio_embedding_from_data(
+                        x=audio_array,
+                        use_tensor=True
+                    )
+                    # Normalize embedding
+                    audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
+
+                return audio_features.cpu().numpy().flatten().tolist()
 
-            return audio_features.cpu().numpy().flatten().tolist()
+            else:
+                raise RuntimeError(f"Unknown CLAP method: {CLAP_METHOD}")
 
         finally:
             # Clean up temp file
@@ -286,6 +340,9 @@ async def encode_audio(request: AudioRequest):
    if not clip_service:
        raise HTTPException(status_code=503, detail="CLAP service not available")
 
+   if not CLAP_AVAILABLE:
+       raise HTTPException(status_code=501, detail="CLAP model not available in this transformers version")
+
    embedding = clip_service.encode_audio(request.audio_url)
    return {"embedding": embedding, "dimensions": len(embedding)}
 
@@ -300,7 +357,10 @@ async def health_check():
 
    return {
        "status": "healthy",
-       "models": ["clip-vit-large-patch14", "clap-htsat-unfused"],
+       "models": {
+           "clip": "clip-vit-large-patch14",
+           "clap": f"clap-htsat-unfused (lazy loaded, method: {CLAP_METHOD})" if CLAP_AVAILABLE else "not available"
+       },
        "device": clip_service.device,
        "service": "ready",
        "cache_dir": cache_dir
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 torch==2.0.1
-transformers==4.30.0
+transformers>=4.35.0
 Pillow==9.5.0
 requests==2.31.0
 fastapi==0.104.1
@@ -9,4 +9,5 @@ pydantic==2.5.0
 numpy<2.0.0
 librosa>=0.10.0
 soundfile>=0.12.1
-datasets>=2.14.0
+datasets>=2.14.0
+laion-clap
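
Because the import fallback at the top of app.py decides the CLAP backend at module load time, a quick sanity check after installing these requirements could look like the sketch below; it simply mirrors that try/except chain and is not part of the commit.

# Mirrors the import fallback added in app.py: transformers' ClapModel first,
# then the laion-clap package, otherwise audio encoding stays unavailable.
try:
    from transformers import ClapModel, ClapProcessor  # may be missing in older transformers releases
    backend = "transformers"
except ImportError:
    try:
        import laion_clap  # provided by the laion-clap requirement
        backend = "laion"
    except ImportError:
        backend = None

print(f"CLAP backend: {backend or 'not available (audio endpoint returns 501)'}")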