taghyan commited on
Commit
cbd9add
·
verified ·
1 Parent(s): 6aa0d0d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from sentence_transformers import SentenceTransformer
3
+ import pickle
4
+ import os
5
+ from pydantic import BaseModel
6
+ import numpy as np
7
+ from typing import List
8
+
9
+ app = FastAPI(
10
+ title="SBERT Embedding API",
11
+ description="API for generating sentence embeddings using SBERT",
12
+ version="1.0"
13
+ )
14
+
15
+ # Load model (this will be cached after first load)
16
+ model_name = 'taghyan/model'
17
+ model = SentenceTransformer(model_name)
18
+
19
+ # Embedding cache setup
20
+ embedding_file = 'embeddings_sbert.pkl'
21
+
22
+ class TextRequest(BaseModel):
23
+ text: str
24
+
25
+ class TextsRequest(BaseModel):
26
+ texts: List[str]
27
+
28
+ class EmbeddingResponse(BaseModel):
29
+ embedding: List[float]
30
+
31
+ class EmbeddingsResponse(BaseModel):
32
+ embeddings: List[List[float]]
33
+
34
+ @app.get("/")
35
+ def read_root():
36
+ return {"message": "SBERT Embedding Service"}
37
+
38
+ @app.post("/embed", response_model=EmbeddingResponse)
39
+ async def embed_text(request: TextRequest):
40
+ """Generate embedding for a single text"""
41
+ embedding = model.encode(request.text, convert_to_numpy=True).tolist()
42
+ return {"embedding": embedding}
43
+
44
+ @app.post("/embed_batch", response_model=EmbeddingsResponse)
45
+ async def embed_texts(request: TextsRequest):
46
+ """Generate embeddings for multiple texts"""
47
+ embeddings = model.encode(request.texts, show_progress_bar=True, convert_to_numpy=True).tolist()
48
+ return {"embeddings": embeddings}
49
+
50
+ @app.post("/update_cache")
51
+ async def update_cache(request: TextsRequest):
52
+ """Update the embedding cache with new texts"""
53
+ if os.path.exists(embedding_file):
54
+ with open(embedding_file, 'rb') as f:
55
+ existing_embeddings = pickle.load(f)
56
+ else:
57
+ existing_embeddings = []
58
+
59
+ new_embeddings = model.encode(request.texts, show_progress_bar=True)
60
+ updated_embeddings = existing_embeddings + new_embeddings.tolist()
61
+
62
+ with open(embedding_file, 'wb') as f:
63
+ pickle.dump(updated_embeddings, f)
64
+
65
+ return {"message": f"Cache updated with {len(request.texts)} new embeddings", "total_embeddings": len(updated_embeddings)}