# model2 / app.py
# taghyan's picture — Create app.py
# cbd9add verified
import os
import pickle
import threading
from typing import List

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# FastAPI application; title/description/version feed the auto-generated
# OpenAPI docs served at /docs.
app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0"
)

# Load model (this will be cached after first load)
# NOTE(review): loading at import time blocks startup until the model is
# downloaded from the Hub — presumably intended for a long-lived service.
model_name = 'taghyan/model'
model = SentenceTransformer(model_name)

# Embedding cache setup: pickle file that /update_cache appends to on disk.
embedding_file = 'embeddings_sbert.pkl'
class TextRequest(BaseModel):
    """Request body for /embed: one text to encode."""
    text: str
class TextsRequest(BaseModel):
    """Request body for /embed_batch and /update_cache: texts to encode."""
    texts: List[str]
class EmbeddingResponse(BaseModel):
    """Response body for /embed: a single embedding vector."""
    embedding: List[float]
class EmbeddingsResponse(BaseModel):
    """Response body for /embed_batch: one embedding vector per input text."""
    embeddings: List[List[float]]
@app.get("/")
def read_root():
    """Root endpoint; doubles as a trivial liveness check."""
    payload = {"message": "SBERT Embedding Service"}
    return payload
@app.post("/embed", response_model=EmbeddingResponse)
def embed_text(request: TextRequest):
    """Generate the SBERT embedding for a single text.

    Declared as a plain ``def`` (not ``async def``): ``model.encode`` is a
    CPU-bound, blocking call, and FastAPI runs sync endpoints in its
    threadpool instead of stalling the event loop for every other request.

    Returns: ``{"embedding": [float, ...]}`` matching EmbeddingResponse.
    """
    embedding = model.encode(request.text, convert_to_numpy=True).tolist()
    return {"embedding": embedding}
@app.post("/embed_batch", response_model=EmbeddingsResponse)
def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts in one batched encode call.

    Plain ``def`` so the blocking, CPU-bound ``model.encode`` runs in
    FastAPI's threadpool rather than blocking the event loop. The
    ``show_progress_bar=True`` of the original was dropped: a tqdm bar
    written to stderr on every request is log noise in a server.

    Returns: ``{"embeddings": [[float, ...], ...]}`` (one row per input).
    """
    embeddings = model.encode(request.texts, convert_to_numpy=True).tolist()
    return {"embeddings": embeddings}
# Serializes the read-modify-write cycle on the cache file: without it two
# concurrent /update_cache requests can both read the old pickle and each
# overwrite the other's appended embeddings (lost update).
_cache_lock = threading.Lock()

@app.post("/update_cache")
def update_cache(request: TextsRequest):
    """Append embeddings for ``request.texts`` to the on-disk pickle cache.

    Plain ``def`` so the blocking encode and file I/O run in FastAPI's
    threadpool. Encoding happens outside the lock (it is the slow part and
    touches no shared state); only the load/append/dump of the pickle file
    is serialized.

    Returns a summary dict with the number of texts added and the new
    cache size.
    """
    # Encode first — no need to hold the lock during the slow model call.
    # (Original's show_progress_bar=True dropped: stderr noise per request.)
    new_embeddings = model.encode(request.texts, convert_to_numpy=True).tolist()

    with _cache_lock:
        if os.path.exists(embedding_file):
            # NOTE(review): unpickling is only safe because this service is
            # the sole writer of embedding_file; never point it at
            # untrusted data.
            with open(embedding_file, 'rb') as f:
                existing_embeddings = pickle.load(f)
        else:
            existing_embeddings = []
        updated_embeddings = existing_embeddings + new_embeddings
        with open(embedding_file, 'wb') as f:
            pickle.dump(updated_embeddings, f)

    return {
        "message": f"Cache updated with {len(request.texts)} new embeddings",
        "total_embeddings": len(updated_embeddings),
    }