Commit 231d431

wip

Files changed:
- .gitattributes     +35 -0
- Dockerfile         +38 -0
- README.md          +10 -0
- app.py             +572 -0
- models_config.py   +60 -0
- requirements.txt   +10 -0
- static/index.html  +142 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,38 @@
FROM python:3.13

# Install basic tools: wget, curl, unzip
RUN apt-get update && \
    apt-get install -y wget curl unzip && \
    rm -rf /var/lib/apt/lists/*

RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt

RUN mkdir -p /home/user/.cache/pip

RUN chown user:user /home/user/.cache/pip
RUN --mount=type=cache,target=/home/user/.cache/pip pip install --upgrade -r requirements.txt

COPY --chown=user . /app


ENV APP_PORT=7860
ENV APP_HOST="0.0.0.0"
ENV ENVIRONMENT="production"
ENV TOKENIZERS_PARALLELISM=false

ENV DEFAULT_MODEL="text-embedding-3-large"
ENV WARMUP_ENABLED=true
ENV CUDA_CACHE_CLEAR_ENABLED=true
ENV EMBEDDING_BATCH_SIZE=8
ENV EMBEDDINGS_CACHE_ENABLED=true
ENV EMBEDDINGS_CACHE_MAXSIZE=2048
ENV REPORT_CACHED_TOKENS=false

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
# CMD ["python", "app.py"]
README.md
ADDED
@@ -0,0 +1,10 @@
---
title: Wip Test
emoji: 🐨
colorFrom: pink
colorTo: gray
sdk: docker
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
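
Editor's note (not part of this commit): once the Space is running, app.py below exposes an OpenAI-compatible embeddings endpoint. The following is a minimal client sketch using httpx, which is already pinned in requirements.txt; the base URL is a placeholder and assumes the Docker defaults (port 7860).

# Hedged example, not part of the commit: call the OpenAI-compatible endpoint.
import httpx

BASE_URL = "http://localhost:7860"  # placeholder; point this at the deployed Space

resp = httpx.post(
    f"{BASE_URL}/v1/embeddings",
    json={
        "input": ["hello world", "embeddings demo"],
        "model": "text-embedding-3-large",  # alias resolved in models_config.py
        "encoding_format": "float",
    },
    timeout=60.0,
)
resp.raise_for_status()
payload = resp.json()
print(payload["usage"])                      # {"prompt_tokens": ..., "total_tokens": ...}
print(len(payload["data"][0]["embedding"]))  # 768 for the gte-multilingual-base alias
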
app.py
ADDED
@@ -0,0 +1,572 @@
import os
import time
import logging
from typing import List, Union, Tuple
from cachetools import LRUCache
import hashlib
import asyncio
from functools import lru_cache
from contextlib import asynccontextmanager

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Configure logging
logging.basicConfig(
    level=(
        logging.DEBUG if os.environ.get("ENVIRONMENT") != "production" else logging.INFO
    ),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field, field_validator, ConfigDict  # Import ConfigDict
from pydantic_settings import BaseSettings
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
import uvicorn
from starlette import status


from models_config import MODELS, get_model_config, CANONICAL_MODELS, MODEL_ALIASES


# --- Configuration Management ---
class AppSettings(BaseSettings):
    """
    Application settings loaded from environment variables.
    """
    cuda_cache_clear_enabled: bool = Field(
        True, json_schema_extra={"env": "CUDA_CACHE_CLEAR_ENABLED"}, description="Enable CUDA cache clearing after each batch."
    )
    default_model: str = Field(
        "text-embedding-3-large", json_schema_extra={"env": "DEFAULT_MODEL"}, description="Default embedding model to use."
    )
    warmup_enabled: bool = Field(
        True, json_schema_extra={"env": "WARMUP_ENABLED"}, description="Enable model warmup on startup."
    )
    app_port: int = Field(
        8000, json_schema_extra={"env": "APP_PORT"}, description="Port for the FastAPI application."
    )
    app_host: str = Field(
        "0.0.0.0", json_schema_extra={"env": "APP_HOST"}, description="Host for the FastAPI application."
    )
    embedding_batch_size: int = Field(
        8, json_schema_extra={"env": "EMBEDDING_BATCH_SIZE"}, description="Batch size for embedding generation."
    )
    embeddings_cache_enabled: bool = Field(
        True, json_schema_extra={"env": "EMBEDDINGS_CACHE_ENABLED"}, description="Enable in-memory embeddings cache."
    )
    report_cached_tokens: bool = Field(
        False, json_schema_extra={"env": "REPORT_CACHED_TOKENS"}, description="Report token count for cached embeddings."
    )
    embeddings_cache_maxsize: int = Field(
        2048, json_schema_extra={"env": "EMBEDDINGS_CACHE_MAXSIZE"}, description="Maximum size of the embeddings cache."
    )
    environment: str = Field(
        "development", json_schema_extra={"env": "ENVIRONMENT"}, description="Application environment (e.g., 'production', 'development')."
    )

    model_config = ConfigDict(env_file=".env")  # Use ConfigDict instead of class Config


@lru_cache()  # Cache the settings instance for performance
def get_app_settings():
    return AppSettings()

# Set up device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# Initialize global embeddings cache (size determined by settings)
embeddings_cache = LRUCache(maxsize=0)  # Will be updated on startup based on settings

# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles application startup and shutdown events.
    Initializes the embeddings cache and warms up the default model.
    """
    settings = get_app_settings()  # Directly get settings here
    global embeddings_cache
    embeddings_cache = LRUCache(maxsize=settings.embeddings_cache_maxsize)
    logger.info(f"Embeddings cache initialized with max size: {settings.embeddings_cache_maxsize}")

    default_model = settings.default_model
    if default_model not in MODELS:
        logger.error(f"Default model '{default_model}' is not configured in MODELS.")
        raise ValueError(
            f"Default model '{default_model}' is not configured in MODELS."
        )
    if settings.warmup_enabled:
        logger.info(f"Warming up default model: {default_model}...")
        try:
            # Pass settings to get_embeddings_batch
            await get_embeddings_batch(["warmup"], default_model, settings)
            logger.info("Model warmup complete.")
        except Exception as e:
            logger.error(f"Model warmup failed for {default_model}: {e}", exc_info=True)

    yield  # Application starts here

    # Clean up code (if any) goes here when application shuts down
    logger.info("Application shutdown.")


app = FastAPI(
    title="Embedding API",
    description="API for generating embeddings using a transformer model.",
    version="0.1.0",
    lifespan=lifespan  # Assign the lifespan context manager
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Mount the static directory to serve index.html and other static files.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize model cache
# to avoid reloading models on every request
model_cache = {}
tokenizer_cache = {}

# New: Initialize global dictionary for model loading locks
model_load_locks = {}


async def load_model(model_name: str):
    """
    Load model and tokenizer if not already loaded, with asynchronous locking.

    Args:
        model_name (str): The name of the model to load.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    config = get_model_config(model_name)
    canonical_hf_model_name = config["name"]

    async with model_load_locks.setdefault(canonical_hf_model_name, asyncio.Lock()):
        if canonical_hf_model_name not in model_cache:
            logger.info(f"Loading model: {canonical_hf_model_name}")
            model_path = config["name"]
            trust_remote = config.get("requires_remote_code", False)

            model_cache[canonical_hf_model_name] = AutoModel.from_pretrained(
                model_path, trust_remote_code=trust_remote
            ).to(DEVICE)
            model_cache[canonical_hf_model_name].eval()

            tokenizer_cache[canonical_hf_model_name] = AutoTokenizer.from_pretrained(model_path)
            logger.info(f"Model loaded: {canonical_hf_model_name}")
    return model_cache[canonical_hf_model_name], tokenizer_cache[canonical_hf_model_name]


class EmbeddingRequest(BaseModel):
    """
    Represents a request for generating embeddings.

    Attributes:
        input (Union[str, List[str]]): The input text to embed, can be a single string or a list of strings.
        model (str): The name of the model to use for embedding.
        encoding_format (str): The format of the embeddings. Currently only 'float' is supported.
    """

    input: Union[str, List[str]] = Field(
        ...,
        description="The input text to embed, can be a single string or a list of strings.",
        json_schema_extra={"example": "This is an example sentence."},
    )
    model: str = Field(
        "text-embedding-3-large",
        description="The name of the model to use for embedding. Supports both original model names and OpenAI-compatible names.",
        json_schema_extra={"example": "text-embedding-3-large"},
    )
    encoding_format: str = Field(
        "float",
        description="The format of the embeddings. Currently only 'float' is supported.",
        json_schema_extra={"example": "float"},
    )

    @field_validator('model')
    @classmethod
    def validate_model(cls, value: str) -> str:
        if value not in MODELS:
            valid_models = list(CANONICAL_MODELS.keys()) + list(MODEL_ALIASES.keys())
            raise ValueError(f"Model must be one of: {', '.join(sorted(valid_models))}")
        return value

    @field_validator('encoding_format')
    @classmethod
    def validate_encoding_format(cls, value: str) -> str:
        if value != "float":
            raise ValueError("Only 'float' encoding format is supported")
        return value


class EmbeddingObject(BaseModel):
    """
    Represents an embedding object.

    Attributes:
        object (str): The type of object, which is "embedding".
        embedding (List[float]): The embedding vector.
        index (int): The index of the embedding.
    """

    object: str = "embedding"
    embedding: List[float]
    index: int


class EmbeddingResponse(BaseModel):
    """
    Represents the response containing a list of embedding objects.
    """

    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: dict


class ModelObject(BaseModel):
    """
    Represents a single model object in the list of models.
    """

    id: str
    object: str = "model"
    created: int
    owned_by: str


class ListModelsResponse(BaseModel):
    """
    Represents the response containing a list of available models.
    """

    data: List[ModelObject]
    object: str = "list"


# --- Helper functions for get_embeddings_batch refactoring ---

def _process_texts_for_cache_and_batching(
    texts: List[str],
    model_config: dict,
    settings: AppSettings
) -> Tuple[List[torch.Tensor], int, List[str], List[int]]:
    """
    Checks cache for each text and prepares texts for model processing.
    Returns cached embeddings, total cached tokens, texts to process, and their original indices.
    """
    final_ordered_embeddings = [None] * len(texts)
    total_prompt_tokens = 0
    texts_to_process_in_model = []
    original_indices_for_model_output = []

    canonical_hf_model_name = model_config["name"]

    for i, text in enumerate(texts):
        text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
        cache_key = (text_hash, canonical_hf_model_name)

        if settings.embeddings_cache_enabled and cache_key in embeddings_cache:
            cached_embedding, cached_tokens = embeddings_cache[cache_key]
            final_ordered_embeddings[i] = cached_embedding.unsqueeze(0)
            if settings.report_cached_tokens:
                total_prompt_tokens += cached_tokens
            logger.debug(f"Cache hit for text at index {i}")
        else:
            texts_to_process_in_model.append(text)
            original_indices_for_model_output.append(i)
            logger.debug(f"Cache miss for text at index {i}")
    return final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output

def _apply_instruction_prefix(texts: List[str], model_config: dict) -> List[str]:
    """
    Applies instruction prefixes to texts if required by the model configuration.
    """
    if model_config.get("instruction_prefix_required", False):
        processed_texts = []
        default_prefix = model_config.get("default_instruction_prefix", "")
        known_prefixes = model_config.get("known_instruction_prefixes", [])
        for text in texts:
            if not any(text.startswith(prefix) for prefix in known_prefixes):
                processed_texts.append(f"{default_prefix}{text}")
            else:
                processed_texts.append(text)
        return processed_texts
    return texts

def _perform_model_inference(
    texts_to_tokenize: List[str],
    model,
    tokenizer,
    model_max_tokens: int,
    model_dimension: int,
    settings: AppSettings
) -> Tuple[torch.Tensor, List[int], int]:
    """
    Performs model inference for a batch of texts and returns embeddings,
    individual token counts, and total prompt tokens for the batch.
    Handles CUDA Out of Memory errors.
    """
    try:
        batch_dict = tokenizer(
            texts_to_tokenize,
            max_length=model_max_tokens,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        individual_tokens_in_batch = [
            int(torch.sum(mask).item()) for mask in batch_dict["attention_mask"]
        ]

        prompt_tokens_current_batch = int(torch.sum(batch_dict["attention_mask"]).item())

        batch_dict = {k: v.to(DEVICE) for k, v in batch_dict.items()}

        with torch.no_grad():
            outputs = model(**batch_dict)

        embeddings = outputs.last_hidden_state[:, 0]
        embeddings = embeddings[:, :model_dimension]
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings, individual_tokens_in_batch, prompt_tokens_current_batch
    except torch.cuda.OutOfMemoryError as e:
        logger.error(
            f"CUDA Out of Memory Error during embedding generation: {e}. "
            "Consider reducing EMBEDDING_BATCH_SIZE or using a smaller model.",
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
            detail=f"GPU out of memory: {e}. Please try with a smaller batch size or a different model."
        )
    except Exception as e:
        logger.error(
            f"An unexpected error occurred during batch embedding generation: {e}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error during embedding generation: {str(e)}"
        )
    finally:
        if settings.cuda_cache_clear_enabled and torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("CUDA cache cleared after processing chunk.")

def _store_embeddings_in_cache(
    embeddings: torch.Tensor,
    individual_tokens_in_batch: List[int],
    batch_original_indices: List[int],
    texts: List[str],
    model_config: dict,
    final_ordered_embeddings: List[Union[torch.Tensor, None]],
    settings: AppSettings
):
    """
    Stores newly generated embeddings in the cache and updates the final ordered embeddings list.
    """
    canonical_hf_model_name = model_config["name"]
    for j, original_idx in enumerate(batch_original_indices):
        current_text = texts[original_idx]
        current_embedding = embeddings[j].cpu()
        current_tokens = individual_tokens_in_batch[j]

        current_text_hash = hashlib.sha256(current_text.encode('utf-8')).hexdigest()
        if settings.embeddings_cache_enabled:
            embeddings_cache[(current_text_hash, canonical_hf_model_name)] = (current_embedding, current_tokens)
        final_ordered_embeddings[original_idx] = current_embedding.unsqueeze(0)


async def get_embeddings_batch(
    texts: List[str],
    model_name: str,
    settings: AppSettings = Depends(get_app_settings)
) -> Tuple[torch.Tensor, int]:
    """
    Generates embeddings for a batch of texts using the specified model.
    Handles potential CUDA out of memory errors by processing texts in chunks.
    Includes an in-memory cache for individual text-model pairs.

    Args:
        texts (List[str]): The list of input texts to embed.
        model_name (str): The name of the model to use.
        settings (AppSettings): Application settings injected via FastAPI's Depends.
    """
    config = get_model_config(model_name)
    model, tokenizer = await load_model(model_name)

    model_max_tokens = config.get("max_tokens", 8192)
    model_dimension = config["dimension"]
    max_batch_size = settings.embedding_batch_size

    final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output = \
        _process_texts_for_cache_and_batching(texts, config, settings)

    if texts_to_process_in_model:
        for i in range(0, len(texts_to_process_in_model), max_batch_size):
            batch_texts = texts_to_process_in_model[i : i + max_batch_size]
            batch_original_indices = original_indices_for_model_output[i : i + max_batch_size]

            texts_to_tokenize = _apply_instruction_prefix(batch_texts, config)

            embeddings, individual_tokens_in_batch, prompt_tokens_current_batch = \
                _perform_model_inference(texts_to_tokenize, model, tokenizer, model_max_tokens, model_dimension, settings)

            total_prompt_tokens += prompt_tokens_current_batch

            _store_embeddings_in_cache(
                embeddings,
                individual_tokens_in_batch,
                batch_original_indices,
                texts,
                config,
                final_ordered_embeddings,
                settings
            )

    final_embeddings_tensor = torch.cat([e for e in final_ordered_embeddings if e is not None], dim=0)
    return final_embeddings_tensor, total_prompt_tokens


@app.get("/", response_class=FileResponse)
async def read_root():
    """
    Serve the static index.html file at the root route.
    """
    return FileResponse("static/index.html")


@app.get("/v1/models", response_model=ListModelsResponse)
async def list_models():
    """
    Lists the available embedding models.
    Returns:
        ListModelsResponse: The response containing a list of model objects.
    """
    model_list = []
    current_time = int(time.time())
    for model_name in MODELS.keys():
        model_list.append(
            ModelObject(
                id=model_name,
                created=current_time,
                owned_by="local",
            )
        )

    return ListModelsResponse(data=model_list)


@app.get("/v1/models/{model_id}", response_model=ModelObject)
async def get_model(model_id: str):
    """
    Retrieves information about a specific embedding model.
    Args:
        model_id (str): The ID of the model to retrieve.
    """
    if model_id in MODELS:
        current_time = int(time.time())
        return ModelObject(
            id=model_id,
            created=current_time,
            owned_by="local",
        )
    else:
        raise HTTPException(status_code=404, detail="Model not found")


@app.post(
    "/api/embed", response_model=EmbeddingResponse
)
@app.post(
    "/v1/embeddings", response_model=EmbeddingResponse
)
async def create_embeddings(request: EmbeddingRequest, settings: AppSettings = Depends(get_app_settings)):
    """
    Generates embeddings for the given input text(s) using batch processing.
    Compatible with OpenAI's Embeddings API format.
    The input can be a single string or a list of strings.
    Returns a list of embedding objects, each containing the embedding vector.
    """
    try:
        start_time = time.time()

        if isinstance(request.input, str):
            texts = [request.input]
        else:
            texts = request.input

        if not texts:
            return EmbeddingResponse(
                data=[],
                model=request.model,
                object="list",
                usage={"prompt_tokens": 0, "total_tokens": 0},
            )

        embeddings_tensor, total_tokens = await get_embeddings_batch(texts, request.model, settings)

        data = [
            EmbeddingObject(embedding=embeddings_tensor[i].tolist(), index=i)
            for i in range(len(texts))
        ]

        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }

        end_time = time.time()
        processing_time = end_time - start_time

        if settings.environment != "production":
            logger.debug(
                f"Processed {len(texts)} inputs in {processing_time:.4f} seconds. "
                f"Model: {request.model}. Tokens: {total_tokens}."
            )

        return EmbeddingResponse(
            data=data, model=request.model, object="list", usage=usage
        )

    except ValueError as e:
        logger.error(f"Validation error in /v1/embeddings: {e}", exc_info=True)
        raise HTTPException(status_code=422, detail=str(e))
    except HTTPException as e:
        logger.error(f"HTTPException in /v1/embeddings: {e.detail}", exc_info=True)
        raise e
    except Exception as e:
        logger.error(f"Unhandled error in /v1/embeddings: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    logger.error(f"Validation error for request to {request.url}: {exc.errors()}")
    raise HTTPException(status_code=422, detail=str(exc.errors()))


if __name__ == "__main__":
    uvicorn.run(app, host=get_app_settings().app_host, port=get_app_settings().app_port)
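
Editor's note (not part of this commit): the in-memory cache in app.py keys each entry by the SHA-256 of the raw input text plus the canonical Hugging Face model name, and stores an (embedding, token_count) pair. A small standalone sketch of that scheme, mirroring _process_texts_for_cache_and_batching and _store_embeddings_in_cache:

# Hedged illustration of the cache-key scheme used in app.py.
import hashlib
from cachetools import LRUCache

cache = LRUCache(maxsize=2048)  # same default as EMBEDDINGS_CACHE_MAXSIZE

def cache_key(text: str, canonical_model: str) -> tuple:
    text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return (text_hash, canonical_model)

key = cache_key("hello world", "Alibaba-NLP/gte-multilingual-base")
cache[key] = ("embedding tensor placeholder", 4)  # app.py stores (embedding, token_count)
assert key in cache  # a later request with the same text and model skips re-inference
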
models_config.py
ADDED
@@ -0,0 +1,60 @@
# models_config.py

CANONICAL_MODELS = {
    "all-MiniLM-L6-v2": {
        "name": "sentence-transformers/all-MiniLM-L6-v2",
        "dimension": 384,
        "requires_remote_code": False,
        "max_tokens": 512,
    },
    "gte-multilingual-base": {
        "name": "Alibaba-NLP/gte-multilingual-base",
        "dimension": 768,
        "requires_remote_code": True,
        "max_tokens": 8192,
    },
    "nomic-embed-text-v1.5": {
        "name": "nomic-ai/nomic-embed-text-v1.5",
        "dimension": 768,
        "requires_remote_code": True,
        "max_tokens": 8192,
        "instruction_prefix_required": True,
        "default_instruction_prefix": "search_document:",
        "known_instruction_prefixes": [
            "search_document:",
            "search_query:",
            "clustering:",
            "classification:",
        ],
    },
    "all-mpnet-base-v2": {
        "name": "sentence-transformers/all-mpnet-base-v2",
        "dimension": 768,
        "requires_remote_code": False,
        "max_tokens": 384,
    },
}

# Mapping of aliases to their canonical model names
MODEL_ALIASES = {
    "all-minilm": "all-MiniLM-L6-v2",
    "text-embedding-3-small": "all-MiniLM-L6-v2",
    "text-embedding-3-large": "gte-multilingual-base",
    "nomic-embed-text": "nomic-embed-text-v1.5",
}

# This global MODELS dictionary will be used for listing available models and validation.
# It combines canonical names and aliases for easy lookup.
MODELS = {**CANONICAL_MODELS, **{alias: CANONICAL_MODELS[canonical] for alias, canonical in MODEL_ALIASES.items()}}

def get_model_config(requested_model_name: str) -> dict:
    """
    Resolves a requested model name (which might be an alias) to its canonical
    configuration. Raises ValueError if the model is not found.
    """
    canonical_name = MODEL_ALIASES.get(requested_model_name, requested_model_name)

    if canonical_name not in CANONICAL_MODELS:
        raise ValueError(f"Model '{requested_model_name}' (canonical: '{canonical_name}') is not a recognized model.")

    return CANONICAL_MODELS[canonical_name]
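
Editor's note (not part of this commit): a brief usage sketch of the alias resolution above; get_model_config maps OpenAI-style names onto the canonical Hugging Face configs.

# Hedged usage sketch for models_config.py.
from models_config import MODEL_ALIASES, get_model_config

cfg = get_model_config("text-embedding-3-large")  # alias defined in MODEL_ALIASES
assert MODEL_ALIASES["text-embedding-3-large"] == "gte-multilingual-base"
assert cfg["name"] == "Alibaba-NLP/gte-multilingual-base"
assert cfg["dimension"] == 768

cfg2 = get_model_config("all-MiniLM-L6-v2")  # canonical names pass through unchanged
assert cfg2["max_tokens"] == 512
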
requirements.txt
ADDED
@@ -0,0 +1,10 @@
einops==0.8.1
cachetools==6.0.0
fastapi==0.115.12
httpx==0.28.1
pydantic==2.11.4
pydantic-settings==2.9.1
pytest==8.3.5
torch==2.7.0
transformers==4.51.3
uvicorn==0.34.2
static/index.html
ADDED
@@ -0,0 +1,142 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Embedding Demo</title>
    <!-- Include Tailwind CSS via CDN -->
    <script src="https://cdn.tailwindcss.com"></script>
    <!-- Include Heroicons for the copy icon -->
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
</head>
<body class="bg-gray-100 flex items-center justify-center min-h-screen">
    <div class="bg-white p-8 rounded-lg shadow-md w-full max-w-md">
        <h1 class="text-2xl font-bold mb-6 text-center">Embedding Demo</h1>
        <div class="mb-4">
            <label for="inputText" class="block text-gray-700 text-sm font-bold mb-2">Enter Text:</label>
            <textarea id="inputText" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline" rows="4" placeholder="Enter text to embed..."></textarea>
        </div>
        <div class="mb-6">
            <label for="modelSelect" class="block text-gray-700 text-sm font-bold mb-2">Select Model:</label>
            <select id="modelSelect" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline">
                <!-- Options will be populated by JavaScript -->
            </select>
        </div>
        <div class="flex items-center justify-between">
            <button id="embedButton" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded focus:outline-none focus:shadow-outline disabled:opacity-50 disabled:cursor-not-allowed">
                Get Embedding
            </button>
        </div>
        <div class="mt-6">
            <label class="block text-gray-700 text-sm font-bold mb-2">Embedding Result:</label>
            <div class="flex items-center bg-gray-200 rounded">
                <div id="result" class="p-4 text-gray-800 text-sm overflow-auto max-h-60 flex-grow">
                    <p>Embedding result will appear here...</p>
                </div>
                <button id="copyButton" class="p-4 self-start text-gray-600 hover:text-gray-800 focus:outline-none">
                    <i class="fas fa-copy"></i>
                </button>
            </div>
        </div>
    </div>

    <script>
        document.addEventListener('DOMContentLoaded', async () => {
            const modelSelect = document.getElementById('modelSelect');

            try {
                const response = await fetch('/v1/models');
                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }
                const data = await response.json();

                // Clear existing options
                modelSelect.innerHTML = '';

                // Populate dropdown with models
                data.data.forEach(model => {
                    const option = document.createElement('option');
                    option.value = model.id;
                    option.textContent = model.id;
                    modelSelect.appendChild(option);
                });

            } catch (error) {
                console.error('Error fetching models:', error);
                // Optionally, add an error message to the dropdown or a separate element
                const option = document.createElement('option');
                option.value = '';
                option.textContent = 'Error loading models';
                modelSelect.appendChild(option);
            }
        });


        document.getElementById('embedButton').addEventListener('click', async () => {
            const inputText = document.getElementById('inputText').value;
            const model = document.getElementById('modelSelect').value;
            const resultDiv = document.getElementById('result');
            const copyButton = document.getElementById('copyButton');
            const embedButton = document.getElementById('embedButton'); // Get button reference

            if (!inputText) {
                resultDiv.textContent = 'Please enter some text.';
                return;
            }

            resultDiv.textContent = 'Fetching embedding...';
            copyButton.style.display = 'none'; // Hide copy button while fetching
            embedButton.disabled = true; // Disable button

            try {
                const response = await fetch('/v1/embeddings', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({
                        input: inputText,
                        model: model,
                        encoding_format: 'float' // Assuming 'float' is the only supported format
                    }),
                });

                if (!response.ok) {
                    const error = await response.json();
                    throw new Error(`HTTP error! status: ${response.status}, detail: ${error.detail}`);
                }

                const data = await response.json();
                resultDiv.textContent = JSON.stringify(data, null, 2); // Display pretty-printed JSON
                copyButton.style.display = 'block'; // Show copy button after result is displayed

            } catch (error) {
                resultDiv.textContent = `Error: ${error.message}`;
                console.error('Error fetching embedding:', error);
                copyButton.style.display = 'none'; // Hide copy button on error
            } finally {
                embedButton.disabled = false; // Re-enable button
            }
        });

        document.getElementById('copyButton').addEventListener('click', async () => {
            const resultDiv = document.getElementById('result');
            const textToCopy = resultDiv.textContent;

            try {
                await navigator.clipboard.writeText(textToCopy);
                // Optional: Provide visual feedback to the user
                alert('Copied to clipboard!');
            } catch (err) {
                console.error('Failed to copy text: ', err);
                // Optional: Provide visual feedback to the user
                alert('Failed to copy text.');
            }
        });

        // Initially hide the copy button
        document.getElementById('copyButton').style.display = 'none';
    </script>
</body>
</html>