rifatramadhani committed
Commit 231d431 · 0 Parent(s)
Files changed (7):
  1. .gitattributes +35 -0
  2. Dockerfile +38 -0
  3. README.md +10 -0
  4. app.py +572 -0
  5. models_config.py +60 -0
  6. requirements.txt +10 -0
  7. static/index.html +142 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM python:3.13
+
+ # Install basic tools: wget, curl, unzip
+ RUN apt-get update && \
+     apt-get install -y wget curl unzip && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+
+ RUN mkdir -p /home/user/.cache/pip
+
+ RUN chown user:user /home/user/.cache/pip
+ RUN --mount=type=cache,target=/home/user/.cache/pip pip install --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+
+
+ ENV APP_PORT=7860
+ ENV APP_HOST="0.0.0.0"
+ ENV ENVIRONMENT="production"
+ ENV TOKENIZERS_PARALLELISM=false
+
+ ENV DEFAULT_MODEL="text-embedding-3-large"
+ ENV WARMUP_ENABLED=true
+ ENV CUDA_CACHE_CLEAR_ENABLED=true
+ ENV EMBEDDING_BATCH_SIZE=8
+ ENV EMBEDDINGS_CACHE_ENABLED=true
+ ENV EMBEDDINGS_CACHE_MAXSIZE=2048
+ ENV REPORT_CACHED_TOKENS=false
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+ # CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Wip Test
+ emoji: 🐨
+ colorFrom: pink
+ colorTo: gray
+ sdk: docker
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,572 @@
+ import os
+ import time
+ import logging
+ from typing import List, Union, Tuple
+ from cachetools import LRUCache
+ import hashlib
+ import asyncio
+ from functools import lru_cache
+ from contextlib import asynccontextmanager
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # Configure logging
+ logging.basicConfig(
+     level=(
+         logging.DEBUG if os.environ.get("ENVIRONMENT") != "production" else logging.INFO
+     ),
+     format="%(asctime)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ from fastapi import FastAPI, HTTPException, Request, Depends
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import FileResponse, JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.exceptions import RequestValidationError
+ from pydantic import BaseModel, Field, field_validator, ConfigDict # Import ConfigDict
+ from pydantic_settings import BaseSettings
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import torch.nn.functional as F
+ import uvicorn
+ from starlette import status
+
+
+ from models_config import MODELS, get_model_config, CANONICAL_MODELS, MODEL_ALIASES
+
+
+ # --- Configuration Management ---
+ class AppSettings(BaseSettings):
+     """
+     Application settings loaded from environment variables.
+     """
+     cuda_cache_clear_enabled: bool = Field(
+         True, json_schema_extra={"env": "CUDA_CACHE_CLEAR_ENABLED"}, description="Enable CUDA cache clearing after each batch."
+     )
+     default_model: str = Field(
+         "text-embedding-3-large", json_schema_extra={"env": "DEFAULT_MODEL"}, description="Default embedding model to use."
+     )
+     warmup_enabled: bool = Field(
+         True, json_schema_extra={"env": "WARMUP_ENABLED"}, description="Enable model warmup on startup."
+     )
+     app_port: int = Field(
+         8000, json_schema_extra={"env": "APP_PORT"}, description="Port for the FastAPI application."
+     )
+     app_host: str = Field(
+         "0.0.0.0", json_schema_extra={"env": "APP_HOST"}, description="Host for the FastAPI application."
+     )
+     embedding_batch_size: int = Field(
+         8, json_schema_extra={"env": "EMBEDDING_BATCH_SIZE"}, description="Batch size for embedding generation."
+     )
+     embeddings_cache_enabled: bool = Field(
+         True, json_schema_extra={"env": "EMBEDDINGS_CACHE_ENABLED"}, description="Enable in-memory embeddings cache."
+     )
+     report_cached_tokens: bool = Field(
+         False, json_schema_extra={"env": "REPORT_CACHED_TOKENS"}, description="Report token count for cached embeddings."
+     )
+     embeddings_cache_maxsize: int = Field(
+         2048, json_schema_extra={"env": "EMBEDDINGS_CACHE_MAXSIZE"}, description="Maximum size of the embeddings cache."
+     )
+     environment: str = Field(
+         "development", json_schema_extra={"env": "ENVIRONMENT"}, description="Application environment (e.g., 'production', 'development')."
+     )
+
+     model_config = ConfigDict(env_file=".env") # Use ConfigDict instead of class Config
+
+
+ @lru_cache() # Cache the settings instance for performance
+ def get_app_settings():
+     return AppSettings()
+
+ # Set up device configuration
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {DEVICE}")
+
+ # Initialize global embeddings cache (size determined by settings)
+ embeddings_cache = LRUCache(maxsize=0) # Will be updated on startup based on settings
+
+ # --- Lifespan Event Handler ---
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """
+     Handles application startup and shutdown events.
+     Initializes the embeddings cache and warms up the default model.
+     """
+     settings = get_app_settings() # Directly get settings here
+     global embeddings_cache
+     embeddings_cache = LRUCache(maxsize=settings.embeddings_cache_maxsize)
+     logger.info(f"Embeddings cache initialized with max size: {settings.embeddings_cache_maxsize}")
+
+     default_model = settings.default_model
+     if default_model not in MODELS:
+         logger.error(f"Default model '{default_model}' is not configured in MODELS.")
+         raise ValueError(
+             f"Default model '{default_model}' is not configured in MODELS."
+         )
+     if settings.warmup_enabled:
+         logger.info(f"Warming up default model: {default_model}...")
+         try:
+             # Pass settings to get_embeddings_batch
+             await get_embeddings_batch(["warmup"], default_model, settings)
+             logger.info("Model warmup complete.")
+         except Exception as e:
+             logger.error(f"Model warmup failed for {default_model}: {e}", exc_info=True)
+
+     yield # Application starts here
+
+     # Clean up code (if any) goes here when application shuts down
+     logger.info("Application shutdown.")
+
+
+ app = FastAPI(
+     title="Embedding API",
+     description="API for generating embeddings using a transformer model.",
+     version="0.1.0",
+     lifespan=lifespan # Assign the lifespan context manager
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"], # Allows all origins
+     allow_credentials=True,
+     allow_methods=["*"], # Allows all methods
+     allow_headers=["*"], # Allows all headers
+ )
+
+ # Mount the static directory to serve index.html and other static files.
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Initialize model cache
+ # to avoid reloading models on every request
+ model_cache = {}
+ tokenizer_cache = {}
+
+ # New: Initialize global dictionary for model loading locks
+ model_load_locks = {}
+
+
+ async def load_model(model_name: str):
+     """
+     Load model and tokenizer if not already loaded, with asynchronous locking.
+
+     Args:
+         model_name (str): The name of the model to load.
+
+     Returns:
+         tuple: A tuple containing the loaded model and tokenizer.
+     """
+     config = get_model_config(model_name)
+     canonical_hf_model_name = config["name"]
+
+     async with model_load_locks.setdefault(canonical_hf_model_name, asyncio.Lock()):
+         if canonical_hf_model_name not in model_cache:
+             logger.info(f"Loading model: {canonical_hf_model_name}")
+             model_path = config["name"]
+             trust_remote = config.get("requires_remote_code", False)
+
+             model_cache[canonical_hf_model_name] = AutoModel.from_pretrained(
+                 model_path, trust_remote_code=trust_remote
+             ).to(DEVICE)
+             model_cache[canonical_hf_model_name].eval()
+
+             tokenizer_cache[canonical_hf_model_name] = AutoTokenizer.from_pretrained(model_path)
+             logger.info(f"Model loaded: {canonical_hf_model_name}")
+     return model_cache[canonical_hf_model_name], tokenizer_cache[canonical_hf_model_name]
+
+
+ class EmbeddingRequest(BaseModel):
+     """
+     Represents a request for generating embeddings.
+
+     Attributes:
+         input (Union[str, List[str]]): The input text to embed, can be a single string or a list of strings.
+         model (str): The name of the model to use for embedding.
+         encoding_format (str): The format of the embeddings. Currently only 'float' is supported.
+     """
+
+     input: Union[str, List[str]] = Field(
+         ...,
+         description="The input text to embed, can be a single string or a list of strings.",
+         json_schema_extra={"example": "This is an example sentence."},
+     )
+     model: str = Field(
+         "text-embedding-3-large",
+         description="The name of the model to use for embedding. Supports both original model names and OpenAI-compatible names.",
+         json_schema_extra={"example": "text-embedding-3-large"},
+     )
+     encoding_format: str = Field(
+         "float",
+         description="The format of the embeddings. Currently only 'float' is supported.",
+         json_schema_extra={"example": "float"},
+     )
+
+     @field_validator('model')
+     @classmethod
+     def validate_model(cls, value: str) -> str:
+         if value not in MODELS:
+             valid_models = list(CANONICAL_MODELS.keys()) + list(MODEL_ALIASES.keys())
+             raise ValueError(f"Model must be one of: {', '.join(sorted(valid_models))}")
+         return value
+
+     @field_validator('encoding_format')
+     @classmethod
+     def validate_encoding_format(cls, value: str) -> str:
+         if value != "float":
+             raise ValueError("Only 'float' encoding format is supported")
+         return value
+
+
+ class EmbeddingObject(BaseModel):
+     """
+     Represents an embedding object.
+
+     Attributes:
+         object (str): The type of object, which is "embedding".
+         embedding (List[float]): The embedding vector.
+         index (int): The index of the embedding.
+     """
+
+     object: str = "embedding"
+     embedding: List[float]
+     index: int
+
+
+ class EmbeddingResponse(BaseModel):
+     """
+     Represents the response containing a list of embedding objects.
+     """
+
+     data: List[EmbeddingObject]
+     model: str
+     object: str = "list"
+     usage: dict
+
+
+ class ModelObject(BaseModel):
+     """
+     Represents a single model object in the list of models.
+     """
+
+     id: str
+     object: str = "model"
+     created: int
+     owned_by: str
+
+
+ class ListModelsResponse(BaseModel):
+     """
+     Represents the response containing a list of available models.
+     """
+
+     data: List[ModelObject]
+     object: str = "list"
+
+
+ # --- Helper functions for get_embeddings_batch refactoring ---
+
+ def _process_texts_for_cache_and_batching(
+     texts: List[str],
+     model_config: dict,
+     settings: AppSettings
+ ) -> Tuple[List[torch.Tensor], int, List[str], List[int]]:
+     """
+     Checks cache for each text and prepares texts for model processing.
+     Returns cached embeddings, total cached tokens, texts to process, and their original indices.
+     """
+     final_ordered_embeddings = [None] * len(texts)
+     total_prompt_tokens = 0
+     texts_to_process_in_model = []
+     original_indices_for_model_output = []
+
+     canonical_hf_model_name = model_config["name"]
+
+     for i, text in enumerate(texts):
+         text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
+         cache_key = (text_hash, canonical_hf_model_name)
+
+         if settings.embeddings_cache_enabled and cache_key in embeddings_cache:
+             cached_embedding, cached_tokens = embeddings_cache[cache_key]
+             final_ordered_embeddings[i] = cached_embedding.unsqueeze(0)
+             if settings.report_cached_tokens:
+                 total_prompt_tokens += cached_tokens
+             logger.debug(f"Cache hit for text at index {i}")
+         else:
+             texts_to_process_in_model.append(text)
+             original_indices_for_model_output.append(i)
+             logger.debug(f"Cache miss for text at index {i}")
+     return final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output
+
+ def _apply_instruction_prefix(texts: List[str], model_config: dict) -> List[str]:
+     """
+     Applies instruction prefixes to texts if required by the model configuration.
+     """
+     if model_config.get("instruction_prefix_required", False):
+         processed_texts = []
+         default_prefix = model_config.get("default_instruction_prefix", "")
+         known_prefixes = model_config.get("known_instruction_prefixes", [])
+         for text in texts:
+             if not any(text.startswith(prefix) for prefix in known_prefixes):
+                 processed_texts.append(f"{default_prefix}{text}")
+             else:
+                 processed_texts.append(text)
+         return processed_texts
+     return texts
+
+ def _perform_model_inference(
+     texts_to_tokenize: List[str],
+     model,
+     tokenizer,
+     model_max_tokens: int,
+     model_dimension: int,
+     settings: AppSettings
+ ) -> Tuple[torch.Tensor, List[int], int]:
+     """
+     Performs model inference for a batch of texts and returns embeddings,
+     individual token counts, and total prompt tokens for the batch.
+     Handles CUDA Out of Memory errors.
+     """
+     try:
+         batch_dict = tokenizer(
+             texts_to_tokenize,
+             max_length=model_max_tokens,
+             padding=True,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         individual_tokens_in_batch = [
+             int(torch.sum(mask).item()) for mask in batch_dict["attention_mask"]
+         ]
+
+         prompt_tokens_current_batch = int(torch.sum(batch_dict["attention_mask"]).item())
+
+         batch_dict = {k: v.to(DEVICE) for k, v in batch_dict.items()}
+
+         with torch.no_grad():
+             outputs = model(**batch_dict)
+
+         embeddings = outputs.last_hidden_state[:, 0]
+         embeddings = embeddings[:, :model_dimension]
+         embeddings = F.normalize(embeddings, p=2, dim=1)
+
+         return embeddings, individual_tokens_in_batch, prompt_tokens_current_batch
+     except torch.cuda.OutOfMemoryError as e:
+         logger.error(
+             f"CUDA Out of Memory Error during embedding generation: {e}. "
+             "Consider reducing EMBEDDING_BATCH_SIZE or using a smaller model.",
+             exc_info=True
+         )
+         raise HTTPException(
+             status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
+             detail=f"GPU out of memory: {e}. Please try with a smaller batch size or a different model."
+         )
+     except Exception as e:
+         logger.error(
+             f"An unexpected error occurred during batch embedding generation: {e}", exc_info=True
+         )
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail=f"Internal server error during embedding generation: {str(e)}"
+         )
+     finally:
+         if settings.cuda_cache_clear_enabled and torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             logger.debug("CUDA cache cleared after processing chunk.")
+
+ def _store_embeddings_in_cache(
+     embeddings: torch.Tensor,
+     individual_tokens_in_batch: List[int],
+     batch_original_indices: List[int],
+     texts: List[str],
+     model_config: dict,
+     final_ordered_embeddings: List[Union[torch.Tensor, None]],
+     settings: AppSettings
+ ):
+     """
+     Stores newly generated embeddings in the cache and updates the final ordered embeddings list.
+     """
+     canonical_hf_model_name = model_config["name"]
+     for j, original_idx in enumerate(batch_original_indices):
+         current_text = texts[original_idx]
+         current_embedding = embeddings[j].cpu()
+         current_tokens = individual_tokens_in_batch[j]
+
+         current_text_hash = hashlib.sha256(current_text.encode('utf-8')).hexdigest()
+         if settings.embeddings_cache_enabled:
+             embeddings_cache[(current_text_hash, canonical_hf_model_name)] = (current_embedding, current_tokens)
+         final_ordered_embeddings[original_idx] = current_embedding.unsqueeze(0)
+
+
+ async def get_embeddings_batch(
+     texts: List[str],
+     model_name: str,
+     settings: AppSettings = Depends(get_app_settings)
+ ) -> Tuple[torch.Tensor, int]:
+     """
+     Generates embeddings for a batch of texts using the specified model.
+     Handles potential CUDA out of memory errors by processing texts in chunks.
+     Includes an in-memory cache for individual text-model pairs.
+
+     Args:
+         texts (List[str]): The list of input texts to embed.
+         model_name (str): The name of the model to use.
+         settings (AppSettings): Application settings injected via FastAPI's Depends.
+     """
+     config = get_model_config(model_name)
+     model, tokenizer = await load_model(model_name)
+
+     model_max_tokens = config.get("max_tokens", 8192)
+     model_dimension = config["dimension"]
+     max_batch_size = settings.embedding_batch_size
+
+     final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output = \
+         _process_texts_for_cache_and_batching(texts, config, settings)
+
+     if texts_to_process_in_model:
+         for i in range(0, len(texts_to_process_in_model), max_batch_size):
+             batch_texts = texts_to_process_in_model[i : i + max_batch_size]
+             batch_original_indices = original_indices_for_model_output[i : i + max_batch_size]
+
+             texts_to_tokenize = _apply_instruction_prefix(batch_texts, config)
+
+             embeddings, individual_tokens_in_batch, prompt_tokens_current_batch = \
+                 _perform_model_inference(texts_to_tokenize, model, tokenizer, model_max_tokens, model_dimension, settings)
+
+             total_prompt_tokens += prompt_tokens_current_batch
+
+             _store_embeddings_in_cache(
+                 embeddings,
+                 individual_tokens_in_batch,
+                 batch_original_indices,
+                 texts,
+                 config,
+                 final_ordered_embeddings,
+                 settings
+             )
+
+     final_embeddings_tensor = torch.cat([e for e in final_ordered_embeddings if e is not None], dim=0)
+     return final_embeddings_tensor, total_prompt_tokens
+
+
+ @app.get("/", response_class=FileResponse)
+ async def read_root():
+     """
+     Serve the static index.html file at the root route.
+     """
+     return FileResponse("static/index.html")
+
+
+ @app.get("/v1/models", response_model=ListModelsResponse)
+ async def list_models():
+     """
+     Lists the available embedding models.
+     Returns:
+         ListModelsResponse: The response containing a list of model objects.
+     """
+     model_list = []
+     current_time = int(time.time())
+     for model_name in MODELS.keys():
+         model_list.append(
+             ModelObject(
+                 id=model_name,
+                 created=current_time,
+                 owned_by="local",
+             )
+         )
+
+     return ListModelsResponse(data=model_list)
+
+
+ @app.get("/v1/models/{model_id}", response_model=ModelObject)
+ async def get_model(model_id: str):
+     """
+     Retrieves information about a specific embedding model.
+     Args:
+         model_id (str): The ID of the model to retrieve.
+     """
+     if model_id in MODELS:
+         current_time = int(time.time())
+         return ModelObject(
+             id=model_id,
+             created=current_time,
+             owned_by="local",
+         )
+     else:
+         raise HTTPException(status_code=404, detail="Model not found")
+
+
+ @app.post(
+     "/api/embed", response_model=EmbeddingResponse
+ )
+ @app.post(
+     "/v1/embeddings", response_model=EmbeddingResponse
+ )
+ async def create_embeddings(request: EmbeddingRequest, settings: AppSettings = Depends(get_app_settings)):
+     """
+     Generates embeddings for the given input text(s) using batch processing.
+     Compatible with OpenAI's Embeddings API format.
+     The input can be a single string or a list of strings.
+     Returns a list of embedding objects, each containing the embedding vector.
+     """
+     try:
+         start_time = time.time()
+
+         if isinstance(request.input, str):
+             texts = [request.input]
+         else:
+             texts = request.input
+
+         if not texts:
+             return EmbeddingResponse(
+                 data=[],
+                 model=request.model,
+                 object="list",
+                 usage={"prompt_tokens": 0, "total_tokens": 0},
+             )
+
+         embeddings_tensor, total_tokens = await get_embeddings_batch(texts, request.model, settings)
+
+         data = [
+             EmbeddingObject(embedding=embeddings_tensor[i].tolist(), index=i)
+             for i in range(len(texts))
+         ]
+
+         usage = {
+             "prompt_tokens": total_tokens,
+             "total_tokens": total_tokens,
+         }
+
+         end_time = time.time()
+         processing_time = end_time - start_time
+
+         if settings.environment != "production":
+             logger.debug(
+                 f"Processed {len(texts)} inputs in {processing_time:.4f} seconds. "
+                 f"Model: {request.model}. Tokens: {total_tokens}."
+             )
+
+         return EmbeddingResponse(
+             data=data, model=request.model, object="list", usage=usage
+         )
+
+     except ValueError as e:
+         logger.error(f"Validation error in /v1/embeddings: {e}", exc_info=True)
+         raise HTTPException(status_code=422, detail=str(e))
+     except HTTPException as e:
+         logger.error(f"HTTPException in /v1/embeddings: {e.detail}", exc_info=True)
+         raise e
+     except Exception as e:
+         logger.error(f"Unhandled error in /v1/embeddings: {e}", exc_info=True)
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
+     logger.error(f"Validation error for request to {request.url}: {exc.errors()}")
+     raise HTTPException(status_code=422, detail=str(exc.errors()))
+
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host=get_app_settings().app_host, port=get_app_settings().app_port)
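
For reference, a minimal client call against the /v1/embeddings endpoint defined above could look like the following. This is a sketch that assumes the service is reachable on localhost:7860 (the port set in the Dockerfile) and uses the httpx package already pinned in requirements.txt:

    import httpx

    payload = {
        "input": ["This is an example sentence."],
        "model": "text-embedding-3-large",  # alias resolved to gte-multilingual-base by models_config.py
        "encoding_format": "float",
    }

    # POST to the OpenAI-compatible endpoint; /api/embed accepts the same body.
    response = httpx.post("http://localhost:7860/v1/embeddings", json=payload, timeout=120.0)
    response.raise_for_status()
    body = response.json()

    # Each item in body["data"] carries an "embedding" vector and its original "index".
    print(len(body["data"][0]["embedding"]), body["usage"])
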
models_config.py ADDED
@@ -0,0 +1,60 @@
+ # models_config.py
+
+ CANONICAL_MODELS = {
+     "all-MiniLM-L6-v2": {
+         "name": "sentence-transformers/all-MiniLM-L6-v2",
+         "dimension": 384,
+         "requires_remote_code": False,
+         "max_tokens": 512,
+     },
+     "gte-multilingual-base": {
+         "name": "Alibaba-NLP/gte-multilingual-base",
+         "dimension": 768,
+         "requires_remote_code": True,
+         "max_tokens": 8192,
+     },
+     "nomic-embed-text-v1.5": {
+         "name": "nomic-ai/nomic-embed-text-v1.5",
+         "dimension": 768,
+         "requires_remote_code": True,
+         "max_tokens": 8192,
+         "instruction_prefix_required": True,
+         "default_instruction_prefix": "search_document:",
+         "known_instruction_prefixes": [
+             "search_document:",
+             "search_query:",
+             "clustering:",
+             "classification:",
+         ],
+     },
+     "all-mpnet-base-v2": {
+         "name": "sentence-transformers/all-mpnet-base-v2",
+         "dimension": 768,
+         "requires_remote_code": False,
+         "max_tokens": 384,
+     },
+ }
+
+ # Mapping of aliases to their canonical model names
+ MODEL_ALIASES = {
+     "all-minilm": "all-MiniLM-L6-v2",
+     "text-embedding-3-small": "all-MiniLM-L6-v2",
+     "text-embedding-3-large": "gte-multilingual-base",
+     "nomic-embed-text": "nomic-embed-text-v1.5",
+ }
+
+ # This global MODELS dictionary will be used for listing available models and validation.
+ # It combines canonical names and aliases for easy lookup.
+ MODELS = {**CANONICAL_MODELS, **{alias: CANONICAL_MODELS[canonical] for alias, canonical in MODEL_ALIASES.items()}}
+
+ def get_model_config(requested_model_name: str) -> dict:
+     """
+     Resolves a requested model name (which might be an alias) to its canonical
+     configuration. Raises ValueError if the model is not found.
+     """
+     canonical_name = MODEL_ALIASES.get(requested_model_name, requested_model_name)
+
+     if canonical_name not in CANONICAL_MODELS:
+         raise ValueError(f"Model '{requested_model_name}' (canonical: '{canonical_name}') is not a recognized model.")
+
+     return CANONICAL_MODELS[canonical_name]
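
As a rough illustration of how the alias handling above behaves (a sketch that only assumes models_config.py is importable from the working directory):

    from models_config import MODELS, get_model_config

    # "text-embedding-3-large" is an alias; it resolves to the gte-multilingual-base entry.
    config = get_model_config("text-embedding-3-large")
    print(config["name"], config["dimension"])  # Alibaba-NLP/gte-multilingual-base 768

    # Unknown names raise ValueError instead of silently falling back.
    try:
        get_model_config("no-such-model")
    except ValueError as err:
        print(err)

    # MODELS holds both canonical names and aliases, so either form passes request validation.
    print(sorted(MODELS))
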
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ einops==0.8.1
+ cachetools==6.0.0
+ fastapi==0.115.12
+ httpx==0.28.1
+ pydantic==2.11.4
+ pydantic-settings==2.9.1
+ pytest==8.3.5
+ torch==2.7.0
+ transformers==4.51.3
+ uvicorn==0.34.2
static/index.html ADDED
@@ -0,0 +1,142 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Embedding Demo</title>
+     <!-- Include Tailwind CSS via CDN -->
+     <script src="https://cdn.tailwindcss.com"></script>
+     <!-- Include Font Awesome for the copy icon -->
+     <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
+ </head>
+ <body class="bg-gray-100 flex items-center justify-center min-h-screen">
+     <div class="bg-white p-8 rounded-lg shadow-md w-full max-w-md">
+         <h1 class="text-2xl font-bold mb-6 text-center">Embedding Demo</h1>
+         <div class="mb-4">
+             <label for="inputText" class="block text-gray-700 text-sm font-bold mb-2">Enter Text:</label>
+             <textarea id="inputText" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline" rows="4" placeholder="Enter text to embed..."></textarea>
+         </div>
+         <div class="mb-6">
+             <label for="modelSelect" class="block text-gray-700 text-sm font-bold mb-2">Select Model:</label>
+             <select id="modelSelect" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline">
+                 <!-- Options will be populated by JavaScript -->
+             </select>
+         </div>
+         <div class="flex items-center justify-between">
+             <button id="embedButton" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded focus:outline-none focus:shadow-outline disabled:opacity-50 disabled:cursor-not-allowed">
+                 Get Embedding
+             </button>
+         </div>
+         <div class="mt-6">
+             <label class="block text-gray-700 text-sm font-bold mb-2">Embedding Result:</label>
+             <div class="flex items-center bg-gray-200 rounded">
+                 <div id="result" class="p-4 text-gray-800 text-sm overflow-auto max-h-60 flex-grow">
+                     <p>Embedding result will appear here...</p>
+                 </div>
+                 <button id="copyButton" class="p-4 self-start text-gray-600 hover:text-gray-800 focus:outline-none">
+                     <i class="fas fa-copy"></i>
+                 </button>
+             </div>
+         </div>
+     </div>
+
+     <script>
+         document.addEventListener('DOMContentLoaded', async () => {
+             const modelSelect = document.getElementById('modelSelect');
+
+             try {
+                 const response = await fetch('/v1/models');
+                 if (!response.ok) {
+                     throw new Error(`HTTP error! status: ${response.status}`);
+                 }
+                 const data = await response.json();
+
+                 // Clear existing options
+                 modelSelect.innerHTML = '';
+
+                 // Populate dropdown with models
+                 data.data.forEach(model => {
+                     const option = document.createElement('option');
+                     option.value = model.id;
+                     option.textContent = model.id;
+                     modelSelect.appendChild(option);
+                 });
+
+             } catch (error) {
+                 console.error('Error fetching models:', error);
+                 // Optionally, add an error message to the dropdown or a separate element
+                 const option = document.createElement('option');
+                 option.value = '';
+                 option.textContent = 'Error loading models';
+                 modelSelect.appendChild(option);
+             }
+         });
+
+
+         document.getElementById('embedButton').addEventListener('click', async () => {
+             const inputText = document.getElementById('inputText').value;
+             const model = document.getElementById('modelSelect').value;
+             const resultDiv = document.getElementById('result');
+             const copyButton = document.getElementById('copyButton');
+             const embedButton = document.getElementById('embedButton'); // Get button reference
+
+             if (!inputText) {
+                 resultDiv.textContent = 'Please enter some text.';
+                 return;
+             }
+
+             resultDiv.textContent = 'Fetching embedding...';
+             copyButton.style.display = 'none'; // Hide copy button while fetching
+             embedButton.disabled = true; // Disable button
+
+             try {
+                 const response = await fetch('/v1/embeddings', {
+                     method: 'POST',
+                     headers: {
+                         'Content-Type': 'application/json',
+                     },
+                     body: JSON.stringify({
+                         input: inputText,
+                         model: model,
+                         encoding_format: 'float' // Assuming 'float' is the only supported format
+                     }),
+                 });
+
+                 if (!response.ok) {
+                     const error = await response.json();
+                     throw new Error(`HTTP error! status: ${response.status}, detail: ${error.detail}`);
+                 }
+
+                 const data = await response.json();
+                 resultDiv.textContent = JSON.stringify(data, null, 2); // Display pretty-printed JSON
+                 copyButton.style.display = 'block'; // Show copy button after result is displayed
+
+             } catch (error) {
+                 resultDiv.textContent = `Error: ${error.message}`;
+                 console.error('Error fetching embedding:', error);
+                 copyButton.style.display = 'none'; // Hide copy button on error
+             } finally {
+                 embedButton.disabled = false; // Re-enable button
+             }
+         });
+
+         document.getElementById('copyButton').addEventListener('click', async () => {
+             const resultDiv = document.getElementById('result');
+             const textToCopy = resultDiv.textContent;
+
+             try {
+                 await navigator.clipboard.writeText(textToCopy);
+                 // Optional: Provide visual feedback to the user
+                 alert('Copied to clipboard!');
+             } catch (err) {
+                 console.error('Failed to copy text: ', err);
+                 // Optional: Provide visual feedback to the user
+                 alert('Failed to copy text.');
+             }
+         });
+
+         // Initially hide the copy button
+         document.getElementById('copyButton').style.display = 'none';
+     </script>
+ </body>
+ </html>