Spaces:
Sleeping
Sleeping
import logging

import torch
from fastapi import FastAPI, HTTPException
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Set up FastAPI app
app = FastAPI()

# Load the tokenizer and model once, at import time. Both must come from the
# same checkpoint, so the id lives in one constant instead of two literals.
MODEL_NAME = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# Candidate labels for zero-shot classification, plus one embedding per label
# computed once at import time.
#
# Each label is tokenized on its own (a batch of one, so padding=True adds no
# pad tokens), and its embedding is the mean of the model's last hidden state
# over the label's tokens. Assuming last_hidden_state is (batch, seq_len, hidden)
# as for Hugging Face encoder models, each row is (1, hidden) and the stacked
# result is (len(labels), hidden), with row i belonging to labels[i].
#
# NOTE(review): the BGE model card recommends CLS-token pooling (+ L2
# normalisation) rather than mean pooling. If that is changed here, the
# /predict handler must embed input text the same way.
labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
with torch.no_grad():
    label_embeddings = torch.vstack(
        [
            model(
                **tokenizer(label, return_tensors="pt", padding=True, truncation=True)
            ).last_hidden_state.mean(dim=1)
            for label in labels
        ]
    )
@app.get("/")
async def root():
    """Return a static welcome message for the service root."""
    return {"message": "Welcome to the Zero-Shot Classification API"}
@app.post("/predict")
async def predict(data: dict):
    """Classify a piece of text into the most similar of the predefined labels.

    Expects a JSON body of the form ``{"data": [text, ...]}``. Only the first
    element of ``data`` is classified; the rest are ignored.

    Returns:
        ``{"label": best}`` where ``best`` is the entry of ``labels`` whose
        embedding has the highest cosine similarity to the text's embedding.

    Raises:
        HTTPException: 422 when ``data`` is missing, is not a non-empty list,
            or its first element is not a string.
    """
    logging.info("Received data: %s", data)

    items = data.get("data")
    # Require a real list: indexing a bare string would silently classify only
    # its first character, and other shapes would surface as a 500 error.
    if not isinstance(items, (list, tuple)) or not items:
        raise HTTPException(
            status_code=422,
            detail='Expected a JSON body like {"data": ["text to classify"]}.',
        )
    text = items[0]
    if not isinstance(text, str):
        raise HTTPException(status_code=422, detail="data[0] must be a string.")

    # Embed the text exactly as the labels are embedded (mean of the last
    # hidden state over the tokens) so the two are comparable.
    # NOTE(review): this is blocking CPU work inside an async handler, so it
    # stalls the event loop for the duration of the forward pass. Moving it to
    # a worker thread needs care, since the shared tokenizer may not be safe to
    # call concurrently.
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_embedding = model(**tokens).last_hidden_state.mean(dim=1)

    # cosine_similarity returns an (n_text_rows, n_labels) matrix; there is a
    # single text row here, so take row 0.
    similarities = cosine_similarity(text_embedding, label_embeddings)[0]
    best_label = labels[int(similarities.argmax())]
    logging.info("Prediction result: %s", best_label)
    return {"label": best_label}