|
import os
import pickle
from typing import List

import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
|
# Application instance that every route decorator in this module registers
# against; the keyword arguments are descriptive metadata for the service.
app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0",
)
|
|
|
|
|
# Identifier handed to SentenceTransformer. The model is constructed once, at
# import time, and the same instance is used by every request handler.
model_name = "taghyan/model"
model = SentenceTransformer(model_name)

# Pickle file that /update_cache reads and rewrites. The path is relative, so
# it resolves against the working directory of the server process.
embedding_file = "embeddings_sbert.pkl"
|
|
|
class TextRequest(BaseModel):
    """Request body carrying a single text to embed."""

    # Input string passed unchanged to the model by /embed.
    text: str
|
|
|
class TextsRequest(BaseModel):
    """Request body carrying several texts to embed in one call."""

    # Input strings; used by both /embed_batch and /update_cache.
    texts: List[str]
|
|
|
class EmbeddingResponse(BaseModel):
    """Response body of /embed: one embedding vector."""

    # Components of the vector the model produced for the submitted text.
    embedding: List[float]
|
|
|
class EmbeddingsResponse(BaseModel):
    """Response body of /embed_batch: one embedding vector per input text."""

    # Outer list has one entry per submitted text; each entry is a vector.
    embeddings: List[List[float]]
|
|
|
@app.get("/")
def read_root():
    """Return a static banner identifying the service."""
    banner = {"message": "SBERT Embedding Service"}
    return banner
|
|
|
@app.post("/embed", response_model=EmbeddingResponse)
async def embed_text(request: TextRequest):
    """Generate the embedding for a single text.

    Args:
        request: Body containing the ``text`` to encode.

    Returns:
        A dict with one key, ``"embedding"``, holding the vector as a list of
        floats.
    """
    # model.encode is synchronous and compute-heavy. Calling it directly in an
    # ``async def`` handler would stall the event loop, and with it every
    # other in-flight request, for the whole inference; run it on a worker
    # thread instead and await the result.
    vector = await run_in_threadpool(
        model.encode, request.text, convert_to_numpy=True
    )
    return {"embedding": vector.tolist()}
|
|
|
@app.post("/embed_batch", response_model=EmbeddingsResponse)
async def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts.

    Args:
        request: Body containing the list of ``texts`` to encode.

    Returns:
        A dict with one key, ``"embeddings"``, holding one vector (a list of
        floats) per input text, in input order. An empty input yields an
        empty list.
    """
    texts = request.texts
    if not texts:
        # Nothing to encode; skip the model call entirely.
        return {"embeddings": []}

    # model.encode is synchronous and compute-heavy, so it is run on a worker
    # thread to keep the event loop free for other requests. The progress bar
    # is disabled: in a server it only writes tqdm output to stderr on every
    # request.
    vectors = await run_in_threadpool(
        model.encode,
        texts,
        show_progress_bar=False,
        convert_to_numpy=True,
    )
    return {"embeddings": vectors.tolist()}
|
|
|
def _load_cached_embeddings(path):
    """Return the embeddings pickled at *path* as a list ([] if no file exists)."""
    try:
        with open(path, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Keep the cache somewhere only this service can write.
            cached = pickle.load(f)
    except FileNotFoundError:
        # First update: nothing has been cached yet.
        return []
    # A cache produced outside this endpoint may be an ndarray; `ndarray +
    # list` would add element-wise instead of concatenating, so normalise it.
    if isinstance(cached, np.ndarray):
        return cached.tolist()
    return cached


def _write_cache_atomically(path, embeddings):
    """Pickle *embeddings* to *path* without ever exposing a partial file."""
    # Dump to a sibling temp file, then rename it over the target. A crash
    # during the dump therefore leaves the previous cache intact instead of a
    # truncated pickle.
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(embeddings, f)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp name is gone; if it still
        # exists the write failed part-way and the leftover is removed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@app.post("/update_cache")
async def update_cache(request: TextsRequest):
    """Append embeddings for the given texts to the on-disk cache.

    Args:
        request: Body containing the list of ``texts`` to encode and store.

    Returns:
        A dict with a human-readable ``"message"`` and ``"total_embeddings"``,
        the number of embeddings in the cache after the update.
    """
    texts = request.texts
    new_rows = []
    if texts:
        # Inference is synchronous and compute-heavy: run it on a worker
        # thread so the event loop keeps serving other requests. The progress
        # bar is disabled because it only adds tqdm noise to the server's
        # stderr.
        encoded = await run_in_threadpool(
            model.encode, texts, show_progress_bar=False
        )
        new_rows = encoded.tolist()

    # There is deliberately no ``await`` between loading and writing, so two
    # requests handled by this process cannot interleave their
    # read-modify-write of the cache file. Separate worker processes are not
    # coordinated and could still overwrite each other's updates.
    updated_embeddings = _load_cached_embeddings(embedding_file) + new_rows
    _write_cache_atomically(embedding_file, updated_embeddings)

    return {
        "message": f"Cache updated with {len(texts)} new embeddings",
        "total_embeddings": len(updated_embeddings),
    }