# extractive_q_n_a / main.py
# Sharathhebbar24's picture
# Rename app.py to main.py
# 568a922 verified
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from scipy.special import softmax
from transformers import (
    BertForQuestionAnswering,
    BertTokenizerFast,
)
from transformers import pipeline
# Checkpoint handed to the question-answering pipeline below.
model_name = 'deepset/bert-base-uncased-squad2'

# Built once at import time; predict_answer() reuses this single instance.
pipe = pipeline(task="question-answering", model=model_name)

# Manual model/tokenizer loading, kept for reference (currently disabled):
# model = BertForQuestionAnswering.from_pretrained(model_name)
# tokenizer = BertTokenizerFast.from_pretrained(model_name)
# Application object; the route decorator in this module registers on it.
app = FastAPI()

# Permissive CORS policy. The three settings have to be passed to the
# middleware: written as bare statements they only bind unused module-level
# tuples and never reach the application.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)
def predict_answer(context: str, question: str) -> dict:
    """Answer *question* from *context* using the module-level QA pipeline.

    Args:
        context: Passage of text the answer is extracted from.
        question: Question to answer against the passage.

    Returns:
        A dict with exactly two keys, ``"answer"`` and ``"score"``, copied
        from the ``answer`` and ``score`` entries of the pipeline response.
        (Presumably ``score`` is the pipeline's confidence for the span --
        confirm against the transformers documentation.)
    """
    # NOTE: an earlier hand-rolled version (tokenize with truncation to 512
    # tokens, run the model under torch.no_grad(), softmax the start/end
    # logits, take the argmax of each, average the two probabilities as the
    # score, and return "No answer found." when the decoded span was the CLS
    # token) was disabled in favour of the pipeline call below.
    response = pipe({"context": context, "question": question})
    return {
        "answer": response['answer'],
        "score": response['score'],
    }
# Request body accepted by the /qna endpoint.
class QnARequest(BaseModel):
    # Passage the answer is extracted from.
    context: str
    # Question to answer against the passage.
    question: str
# Response body returned by the /qna endpoint.
class QnAResponse(BaseModel):
    # Answer text produced for the question.
    answer: str
    # Score reported alongside the answer.
    confidence: float
@app.post("/qna", response_model=QnAResponse)
async def extractive_qna(request: QnARequest):
    """Handle POST /qna: answer the request's question from its context.

    Args:
        request: Validated body carrying ``context`` and ``question``.

    Returns:
        A ``QnAResponse`` whose ``answer`` and ``confidence`` come from the
        ``"answer"`` and ``"score"`` entries returned by ``predict_answer``.

    Raises:
        HTTPException: 400 when either field is empty or whitespace-only;
            500 when prediction or response construction fails.
    """
    context = request.context
    question = request.question

    # Validate outside the try block so the 400 can never be caught by the
    # broad handler below and re-reported as a 500. Whitespace-only input is
    # treated as empty because it carries nothing to answer from.
    if not context.strip() or not question.strip():
        raise HTTPException(status_code=400, detail="Context and question cannot be empty.")

    try:
        # NOTE(review): predict_answer is called synchronously inside this
        # coroutine, so the call runs to completion before control returns.
        result = predict_answer(context, question)
        return QnAResponse(answer=result["answer"], confidence=result["score"])
    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(status_code=500, detail=f"Error processing QnA: {str(e)}") from e