import os import time import logging from typing import List, Union, Tuple from cachetools import LRUCache import hashlib import asyncio from functools import lru_cache from contextlib import asynccontextmanager os.environ["TOKENIZERS_PARALLELISM"] = "false" # Configure logging logging.basicConfig( level=( logging.DEBUG if os.environ.get("ENVIRONMENT") != "production" else logging.INFO ), format="%(asctime)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) from fastapi import FastAPI, HTTPException, Request, Depends from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, Field, field_validator, ConfigDict # Import ConfigDict from pydantic_settings import BaseSettings from transformers import AutoModel, AutoTokenizer import torch import torch.nn.functional as F import uvicorn from starlette import status from models_config import MODELS, get_model_config, CANONICAL_MODELS, MODEL_ALIASES # --- Configuration Management --- class AppSettings(BaseSettings): """ Application settings loaded from environment variables. """ cuda_cache_clear_enabled: bool = Field( True, json_schema_extra={"env": "CUDA_CACHE_CLEAR_ENABLED"}, description="Enable CUDA cache clearing after each batch." ) default_model: str = Field( "text-embedding-3-large", json_schema_extra={"env": "DEFAULT_MODEL"}, description="Default embedding model to use." ) warmup_enabled: bool = Field( True, json_schema_extra={"env": "WARMUP_ENABLED"}, description="Enable model warmup on startup." ) app_port: int = Field( 8000, json_schema_extra={"env": "APP_PORT"}, description="Port for the FastAPI application." ) app_host: str = Field( "0.0.0.0", json_schema_extra={"env": "APP_HOST"}, description="Host for the FastAPI application." ) embedding_batch_size: int = Field( 8, json_schema_extra={"env": "EMBEDDING_BATCH_SIZE"}, description="Batch size for embedding generation." ) embeddings_cache_enabled: bool = Field( True, json_schema_extra={"env": "EMBEDDINGS_CACHE_ENABLED"}, description="Enable in-memory embeddings cache." ) report_cached_tokens: bool = Field( False, json_schema_extra={"env": "REPORT_CACHED_TOKENS"}, description="Report token count for cached embeddings." ) embeddings_cache_maxsize: int = Field( 2048, json_schema_extra={"env": "EMBEDDINGS_CACHE_MAXSIZE"}, description="Maximum size of the embeddings cache." ) environment: str = Field( "development", json_schema_extra={"env": "ENVIRONMENT"}, description="Application environment (e.g., 'production', 'development')." ) model_config = ConfigDict(env_file=".env") # Use ConfigDict instead of class Config @lru_cache() # Cache the settings instance for performance def get_app_settings(): return AppSettings() # Set up device configuration DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {DEVICE}") # Initialize global embeddings cache (size determined by settings) embeddings_cache = LRUCache(maxsize=0) # Will be updated on startup based on settings # --- Lifespan Event Handler --- @asynccontextmanager async def lifespan(app: FastAPI): """ Handles application startup and shutdown events. Initializes the embeddings cache and warms up the default model. """ settings = get_app_settings() # Directly get settings here global embeddings_cache embeddings_cache = LRUCache(maxsize=settings.embeddings_cache_maxsize) logger.info(f"Embeddings cache initialized with max size: {settings.embeddings_cache_maxsize}") default_model = settings.default_model if default_model not in MODELS: logger.error(f"Default model '{default_model}' is not configured in MODELS.") raise ValueError( f"Default model '{default_model}' is not configured in MODELS." ) if settings.warmup_enabled: logger.info(f"Warming up default model: {default_model}...") try: # Pass settings to get_embeddings_batch await get_embeddings_batch(["warmup"], default_model, settings) logger.info("Model warmup complete.") except Exception as e: logger.error(f"Model warmup failed for {default_model}: {e}", exc_info=True) yield # Application starts here # Clean up code (if any) goes here when application shuts down logger.info("Application shutdown.") app = FastAPI( title="Embedding API", description="API for generating embeddings using a transformer model.", version="0.1.0", lifespan=lifespan # Assign the lifespan context manager ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # Allows all origins allow_credentials=True, allow_methods=["*"], # Allows all methods allow_headers=["*"], # Allows all headers ) # Mount the static directory to serve index.html and other static files. app.mount("/static", StaticFiles(directory="static"), name="static") # Initialize model cache # to avoid reloading models on every request model_cache = {} tokenizer_cache = {} # New: Initialize global dictionary for model loading locks model_load_locks = {} async def load_model(model_name: str): """ Load model and tokenizer if not already loaded, with asynchronous locking. Args: model_name (str): The name of the model to load. Returns: tuple: A tuple containing the loaded model and tokenizer. """ config = get_model_config(model_name) canonical_hf_model_name = config["name"] async with model_load_locks.setdefault(canonical_hf_model_name, asyncio.Lock()): if canonical_hf_model_name not in model_cache: logger.info(f"Loading model: {canonical_hf_model_name}") model_path = config["name"] trust_remote = config.get("requires_remote_code", False) model_cache[canonical_hf_model_name] = AutoModel.from_pretrained( model_path, trust_remote_code=trust_remote ).to(DEVICE) model_cache[canonical_hf_model_name].eval() tokenizer_cache[canonical_hf_model_name] = AutoTokenizer.from_pretrained(model_path) logger.info(f"Model loaded: {canonical_hf_model_name}") return model_cache[canonical_hf_model_name], tokenizer_cache[canonical_hf_model_name] class EmbeddingRequest(BaseModel): """ Represents a request for generating embeddings. Attributes: input (Union[str, List[str]]): The input text to embed, can be a single string or a list of strings. model (str): The name of the model to use for embedding. encoding_format (str): The format of the embeddings. Currently only 'float' is supported. """ input: Union[str, List[str]] = Field( ..., description="The input text to embed, can be a single string or a list of strings.", json_schema_extra={"example": "This is an example sentence."}, ) model: str = Field( "text-embedding-3-large", description="The name of the model to use for embedding. Supports both original model names and OpenAI-compatible names.", json_schema_extra={"example": "text-embedding-3-large"}, ) encoding_format: str = Field( "float", description="The format of the embeddings. Currently only 'float' is supported.", json_schema_extra={"example": "float"}, ) @field_validator('model') @classmethod def validate_model(cls, value: str) -> str: if value not in MODELS: valid_models = list(CANONICAL_MODELS.keys()) + list(MODEL_ALIASES.keys()) raise ValueError(f"Model must be one of: {', '.join(sorted(valid_models))}") return value @field_validator('encoding_format') @classmethod def validate_encoding_format(cls, value: str) -> str: if value != "float": raise ValueError("Only 'float' encoding format is supported") return value class EmbeddingObject(BaseModel): """ Represents an embedding object. Attributes: object (str): The type of object, which is "embedding". embedding (List[float]): The embedding vector. index (int): The index of the embedding. """ object: str = "embedding" embedding: List[float] index: int class EmbeddingResponse(BaseModel): """ Represents the response containing a list of embedding objects. """ data: List[EmbeddingObject] model: str object: str = "list" usage: dict class ModelObject(BaseModel): """ Represents a single model object in the list of models. """ id: str object: str = "model" created: int owned_by: str class ListModelsResponse(BaseModel): """ Represents the response containing a list of available models. """ data: List[ModelObject] object: str = "list" # --- Helper functions for get_embeddings_batch refactoring --- def _process_texts_for_cache_and_batching( texts: List[str], model_config: dict, settings: AppSettings ) -> Tuple[List[torch.Tensor], int, List[str], List[int]]: """ Checks cache for each text and prepares texts for model processing. Returns cached embeddings, total cached tokens, texts to process, and their original indices. """ final_ordered_embeddings = [None] * len(texts) total_prompt_tokens = 0 texts_to_process_in_model = [] original_indices_for_model_output = [] canonical_hf_model_name = model_config["name"] for i, text in enumerate(texts): text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest() cache_key = (text_hash, canonical_hf_model_name) if settings.embeddings_cache_enabled and cache_key in embeddings_cache: cached_embedding, cached_tokens = embeddings_cache[cache_key] final_ordered_embeddings[i] = cached_embedding.unsqueeze(0) if settings.report_cached_tokens: total_prompt_tokens += cached_tokens logger.debug(f"Cache hit for text at index {i}") else: texts_to_process_in_model.append(text) original_indices_for_model_output.append(i) logger.debug(f"Cache miss for text at index {i}") return final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output def _apply_instruction_prefix(texts: List[str], model_config: dict) -> List[str]: """ Applies instruction prefixes to texts if required by the model configuration. """ if model_config.get("instruction_prefix_required", False): processed_texts = [] default_prefix = model_config.get("default_instruction_prefix", "") known_prefixes = model_config.get("known_instruction_prefixes", []) for text in texts: if not any(text.startswith(prefix) for prefix in known_prefixes): processed_texts.append(f"{default_prefix}{text}") else: processed_texts.append(text) return processed_texts return texts def _perform_model_inference( texts_to_tokenize: List[str], model, tokenizer, model_max_tokens: int, model_dimension: int, settings: AppSettings ) -> Tuple[torch.Tensor, List[int], int]: """ Performs model inference for a batch of texts and returns embeddings, individual token counts, and total prompt tokens for the batch. Handles CUDA Out of Memory errors. """ try: batch_dict = tokenizer( texts_to_tokenize, max_length=model_max_tokens, padding=True, truncation=True, return_tensors="pt", ) individual_tokens_in_batch = [ int(torch.sum(mask).item()) for mask in batch_dict["attention_mask"] ] prompt_tokens_current_batch = int(torch.sum(batch_dict["attention_mask"]).item()) batch_dict = {k: v.to(DEVICE) for k, v in batch_dict.items()} with torch.no_grad(): outputs = model(**batch_dict) embeddings = outputs.last_hidden_state[:, 0] embeddings = embeddings[:, :model_dimension] embeddings = F.normalize(embeddings, p=2, dim=1) return embeddings, individual_tokens_in_batch, prompt_tokens_current_batch except torch.cuda.OutOfMemoryError as e: logger.error( f"CUDA Out of Memory Error during embedding generation: {e}. " "Consider reducing EMBEDDING_BATCH_SIZE or using a smaller model.", exc_info=True ) raise HTTPException( status_code=status.HTTP_507_INSUFFICIENT_STORAGE, detail=f"GPU out of memory: {e}. Please try with a smaller batch size or a different model." ) except Exception as e: logger.error( f"An unexpected error occurred during batch embedding generation: {e}", exc_info=True ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error during embedding generation: {str(e)}" ) finally: if settings.cuda_cache_clear_enabled and torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("CUDA cache cleared after processing chunk.") def _store_embeddings_in_cache( embeddings: torch.Tensor, individual_tokens_in_batch: List[int], batch_original_indices: List[int], texts: List[str], model_config: dict, final_ordered_embeddings: List[Union[torch.Tensor, None]], settings: AppSettings ): """ Stores newly generated embeddings in the cache and updates the final ordered embeddings list. """ canonical_hf_model_name = model_config["name"] for j, original_idx in enumerate(batch_original_indices): current_text = texts[original_idx] current_embedding = embeddings[j].cpu() current_tokens = individual_tokens_in_batch[j] current_text_hash = hashlib.sha256(current_text.encode('utf-8')).hexdigest() if settings.embeddings_cache_enabled: embeddings_cache[(current_text_hash, canonical_hf_model_name)] = (current_embedding, current_tokens) final_ordered_embeddings[original_idx] = current_embedding.unsqueeze(0) async def get_embeddings_batch( texts: List[str], model_name: str, settings: AppSettings = Depends(get_app_settings) ) -> Tuple[torch.Tensor, int]: """ Generates embeddings for a batch of texts using the specified model. Handles potential CUDA out of memory errors by processing texts in chunks. Includes an in-memory cache for individual text-model pairs. Args: texts (List[str]): The list of input texts to embed. model_name (str): The name of the model to use. settings (AppSettings): Application settings injected via FastAPI's Depends. """ config = get_model_config(model_name) model, tokenizer = await load_model(model_name) model_max_tokens = config.get("max_tokens", 8192) model_dimension = config["dimension"] max_batch_size = settings.embedding_batch_size final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output = \ _process_texts_for_cache_and_batching(texts, config, settings) if texts_to_process_in_model: for i in range(0, len(texts_to_process_in_model), max_batch_size): batch_texts = texts_to_process_in_model[i : i + max_batch_size] batch_original_indices = original_indices_for_model_output[i : i + max_batch_size] texts_to_tokenize = _apply_instruction_prefix(batch_texts, config) embeddings, individual_tokens_in_batch, prompt_tokens_current_batch = \ _perform_model_inference(texts_to_tokenize, model, tokenizer, model_max_tokens, model_dimension, settings) total_prompt_tokens += prompt_tokens_current_batch _store_embeddings_in_cache( embeddings, individual_tokens_in_batch, batch_original_indices, texts, config, final_ordered_embeddings, settings ) final_embeddings_tensor = torch.cat([e for e in final_ordered_embeddings if e is not None], dim=0) return final_embeddings_tensor, total_prompt_tokens @app.get("/", response_class=FileResponse) async def read_root(): """ Serve the static index.html file at the root route. """ return FileResponse("static/index.html") @app.get("/v1/models", response_model=ListModelsResponse) async def list_models(): """ Lists the available embedding models. Returns: ListModelsResponse: The response containing a list of model objects. """ model_list = [] current_time = int(time.time()) for model_name in MODELS.keys(): model_list.append( ModelObject( id=model_name, created=current_time, owned_by="local", ) ) return ListModelsResponse(data=model_list) @app.get("/v1/models/{model_id}", response_model=ModelObject) async def get_model(model_id: str): """ Retrieves information about a specific embedding model. Args: model_id (str): The ID of the model to retrieve. """ if model_id in MODELS: current_time = int(time.time()) return ModelObject( id=model_id, created=current_time, owned_by="local", ) else: raise HTTPException(status_code=404, detail="Model not found") @app.post( "/api/embed", response_model=EmbeddingResponse ) @app.post( "/v1/embeddings", response_model=EmbeddingResponse ) async def create_embeddings(request: EmbeddingRequest, settings: AppSettings = Depends(get_app_settings)): """ Generates embeddings for the given input text(s) using batch processing. Compatible with OpenAI's Embeddings API format. The input can be a single string or a list of strings. Returns a list of embedding objects, each containing the embedding vector. """ try: start_time = time.time() if isinstance(request.input, str): texts = [request.input] else: texts = request.input if not texts: return EmbeddingResponse( data=[], model=request.model, object="list", usage={"prompt_tokens": 0, "total_tokens": 0}, ) embeddings_tensor, total_tokens = await get_embeddings_batch(texts, request.model, settings) data = [ EmbeddingObject(embedding=embeddings_tensor[i].tolist(), index=i) for i in range(len(texts)) ] usage = { "prompt_tokens": total_tokens, "total_tokens": total_tokens, } end_time = time.time() processing_time = end_time - start_time if settings.environment != "production": logger.debug( f"Processed {len(texts)} inputs in {processing_time:.4f} seconds. " f"Model: {request.model}. Tokens: {total_tokens}." ) return EmbeddingResponse( data=data, model=request.model, object="list", usage=usage ) except ValueError as e: logger.error(f"Validation error in /v1/embeddings: {e}", exc_info=True) raise HTTPException(status_code=422, detail=str(e)) except HTTPException as e: logger.error(f"HTTPException in /v1/embeddings: {e.detail}", exc_info=True) raise e except Exception as e: logger.error(f"Unhandled error in /v1/embeddings: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): logger.error(f"Validation error for request to {request.url}: {exc.errors()}") raise HTTPException(status_code=422, detail=str(exc.errors())) if __name__ == "__main__": uvicorn.run(app, host=get_app_settings().app_host, port=get_app_settings().app_port)