Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import List, Dict, Optional, Any, Mapping | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import cos_sim | |
from sentence_transformers.quantization import quantize_embeddings | |
import torch.nn.functional as F | |
# 1. Specify preferred dimensions
# dimensions = 512 | |
# # 2. load model | |
# model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=dimensions) | |
# # For retrieval you need to pass this prompt. | |
# query = 'Represent this sentence for searching relevant passages: A man is eating a piece of bread' | |
# docs = [ | |
# query, | |
# "A man is eating food.", | |
# "A man is eating pasta.", | |
# "The girl is carrying a baby.", | |
# "A man is riding a horse.", | |
# ] | |
# # 2. Encode | |
# embeddings = model.encode(docs) | |
# # Optional: Quantize the embeddings | |
# binary_embeddings = quantize_embeddings(embeddings, precision="ubinary") | |
# similarities = cos_sim(embeddings[0], embeddings[1:]) | |
# print('similarities:', similarities) | |
# FastAPI application object for this service.
# NOTE(review): no route decorators are visible on the handler functions in
# this view — confirm they are actually registered on `app`.
app = FastAPI()
# Request body accepted by `get_embeddings`.
# (Documented with comments rather than a docstring so the generated schema
# description is left unchanged.)
class EmbeddingRequest(BaseModel):
    # Key into the module-level `models` registry selecting the encoder.
    model: str
    # Text handed to the selected encoder's `encode` call.
    prompt: str
    # Optional extra settings; `get_embeddings` reads this into a local but
    # does not forward it to `encode`.
    options: Optional[dict] = None
    # NOTE(review): not read by any code visible in this file — confirm
    # whether normalization was meant to be applied.
    normalize: bool = True
# Registry of available encoders, keyed by the name clients send in
# `EmbeddingRequest.model`. The SentenceTransformer objects are constructed
# here, at module import time, so importing this module builds every model.
models = {
    "mxbai-embed-large":SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1"),
    # NOTE(review): trust_remote_code=True — per the library documentation this
    # permits custom code from the model repository to run; confirm the
    # repository is trusted (and ideally pinned to a revision).
    "nomic-embed-text": SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
}
async def get_models():
    """Return a mapping from each registered model name to its config.

    Iterates the module-level ``models`` registry and converts each entry's
    ``config`` attribute to a plain ``dict``.

    NOTE(review): assumes every registered model object exposes a
    mapping-like ``config`` attribute — TODO confirm against the installed
    sentence-transformers version.
    """
    configs_by_name = {}
    for name, encoder in models.items():
        configs_by_name[name] = dict(encoder.config)
    return configs_by_name
def greet_json():
    """Return the fixed greeting payload as a one-entry dict."""
    payload = {}
    payload["Hello"] = "World!"
    return payload
async def get_embeddings(request: EmbeddingRequest):
    """Encode ``request.prompt`` with the requested model and return the vector.

    Args:
        request: Request body; ``request.model`` selects an entry of the
            module-level ``models`` registry and ``request.prompt`` is the
            text to embed.

    Returns:
        ``{"embedding": [...]}`` where the value is the encoder output
        converted with ``tolist()``.

    Raises:
        HTTPException: 404 when ``request.model`` is not a key of ``models``;
            400 when the encoder raises ``ValueError``.
    """
    # An unknown key previously escaped as an unhandled KeyError (HTTP 500);
    # report it to the client as a proper "not found" error instead.
    encoder = models.get(request.model)
    if encoder is None:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown model {request.model!r}; available: {sorted(models)}",
        )

    # NOTE(review): `request.options` and `request.normalize` are accepted by
    # the schema but deliberately not applied here, matching the previous
    # behaviour (the options pass-through and normalization were commented
    # out). Confirm the intended contract before wiring them in, since doing
    # so changes the returned vectors.
    try:
        embeddings = encoder.encode(request.prompt, convert_to_tensor=True)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"embedding": embeddings.tolist()}