|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
app = FastAPI()

# Only the local dev frontend may call this API; extend for production hosts.
origins = [
    "http://localhost:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fine-tuned BERT checkpoint for 3-way bias classification (see labels in /predict).
MODEL_PATH = "./bert-bias-detector/checkpoint-4894"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# BUG FIX: switch to evaluation mode. HF models come up in train mode, which
# leaves dropout active and makes inference non-deterministic request-to-request.
model.eval()
|
|
|
|
|
class InputText(BaseModel):
    """Request body for ``/predict``: the raw article text to classify."""

    text: str
|
|
|
@app.post("/predict") |
|
async def predict_text(payload: InputText): |
|
inputs = tokenizer(payload.text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
logits = outputs.logits |
|
probs = logits.softmax(dim=-1)[0].tolist() |
|
|
|
labels = ["Left", "Center", "Right"] |
|
predicted_label = labels[torch.argmax(logits).item()] |
|
|
|
return { |
|
"bias_scores": probs, |
|
"predicted": predicted_label |
|
} |