# embeddings / app.py
# Author: jonathanjordan21
# Last commit: ea4f0af "Update app.py" (verified)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Optional, Any, Mapping
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.quantization import quantize_embeddings
import torch.nn.functional as F
# 1. Specify preferred dimensions
# dimensions = 512
# # 2. load model
# model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=dimensions)
# # For retrieval you need to pass this prompt.
# query = 'Represent this sentence for searching relevant passages: A man is eating a piece of bread'
# docs = [
# query,
# "A man is eating food.",
# "A man is eating pasta.",
# "The girl is carrying a baby.",
# "A man is riding a horse.",
# ]
# # 3. Encode
# embeddings = model.encode(docs)
# # Optional: Quantize the embeddings
# binary_embeddings = quantize_embeddings(embeddings, precision="ubinary")
# similarities = cos_sim(embeddings[0], embeddings[1:])
# print('similarities:', similarities)
# The ASGI application; every route handler below registers itself on this instance.
app = FastAPI()
# Request body for POST /api/embeddings.
# Only `#` comments are used here: a class docstring would become the
# generated JSON-schema description.
class EmbeddingRequest(BaseModel):
    # Name of the encoder to use; looked up as a key of the module-level `models` dict.
    model: str
    # The text to embed (a single string).
    prompt: str
    # Optional free-form options.
    # NOTE(review): not currently forwarded to the encoder — confirm intent.
    options: Optional[dict] = None
    # Normalization flag, defaulting to True.
    # NOTE(review): not currently read by the embeddings handler — confirm intent.
    normalize: bool = True
# Registry of served encoders, keyed by the name clients send as
# `EmbeddingRequest.model`. Both are constructed at import time, so server
# startup pays the full model-loading cost.
models = {
    "mxbai-embed-large": SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1"),
    # trust_remote_code=True allows this Hub repository to run its own Python
    # modeling code; only acceptable for a trusted repository.
    "nomic-embed-text": SentenceTransformer(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    ),
}
@app.get("/api/tags")
async def get_models():
    """Return each served model's configuration as a dict, keyed by model name.

    NOTE(review): relies on every registered SentenceTransformer exposing a
    ``.config`` attribute that ``dict()`` can consume; confirm against the
    installed sentence-transformers version.
    """
    configs_by_name = {}
    for name, encoder in models.items():
        configs_by_name[name] = dict(encoder.config)
    return configs_by_name
@app.get("/")
def greet_json():
    """Return a fixed JSON greeting from the root path."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/api/embeddings")
async def get_embeddings(request: EmbeddingRequest):
    """Embed ``request.prompt`` with the encoder named by ``request.model``.

    Returns:
        ``{"embedding": [float, ...]}``. The prompt is a single string, so
        ``encode`` produces one vector, which is returned as a flat list.

    Raises:
        HTTPException: 404 if ``request.model`` is not a key of ``models``;
            400 if encoding raises ``ValueError``.
    """
    try:
        encoder = models[request.model]
    except KeyError:
        # Without this, an unknown model name escaped as an unhandled
        # KeyError and surfaced to the client as an opaque HTTP 500.
        raise HTTPException(
            status_code=404, detail=f"model '{request.model}' not found"
        ) from None

    # NOTE(review): request.options and request.normalize are accepted but not
    # applied; the vector is exactly what encode() returns for this model.
    # NOTE: encode() is synchronous and runs directly on the event loop, so it
    # blocks other requests for its duration.
    try:
        embedding = encoder.encode(request.prompt, convert_to_tensor=True)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"embedding": embedding.tolist()}