from typing import List

import dill
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from model_load_save import load_model


def load_preprocessing_components():
    """Load the fitted encoder and scaler that were serialized with dill."""
    with open("encoder.pkl", "rb") as f:
        encoder = dill.load(f)
    with open("scaler.pkl", "rb") as f:
        scaler = dill.load(f)
    return encoder, scaler


app = FastAPI()

# Load the trained model and the fitted preprocessing artifacts once at
# import time, so every request reuses them instead of reloading from disk.
model = load_model()
encoder, scaler = load_preprocessing_components()


class InferenceData(BaseModel):
    """Schema for a single inference record."""

    Age: float
    Sex: str
    ChestPainType: str
    RestingBP: float
    Cholesterol: float
    FastingBS: int
    RestingECG: str
    MaxHR: float
    ExerciseAngina: str
    Oldpeak: float
    ST_Slope: str


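# Illustrative request body for /predict. The field names come from the
# InferenceData schema above; the categorical values are assumptions except
# where implied by the encoded column names used in preprocess_data
# (ExerciseAngina "Y"/"N", ST_Slope "Up"/"Flat"):
# {
#     "Age": 54, "Sex": "M", "ChestPainType": "ATA", "RestingBP": 130,
#     "Cholesterol": 220, "FastingBS": 0, "RestingECG": "Normal",
#     "MaxHR": 150, "ExerciseAngina": "N", "Oldpeak": 1.0, "ST_Slope": "Up"
# }

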
@app.get("/")
|
|
def read_root():
|
|
return {"message": "Inference API is up and running"}
|
|
|
|
|
|
|
|
def preprocess_data(df: pd.DataFrame) -> np.ndarray:
    """Apply the fitted encoder and scaler and return a model-ready array."""
    # One-hot encode the categorical columns the encoder was fitted on.
    encoded = encoder.transform(df[encoder.feature_names_in_])
    encoded_df = pd.DataFrame(
        encoded, columns=encoder.get_feature_names_out(), index=df.index
    )

    # Replace the raw categorical columns with their encoded counterparts.
    df = pd.concat([df.drop(encoder.feature_names_in_, axis=1), encoded_df], axis=1)

    # Keep only the features the model was trained on, then scale them.
    df_selected = df[['Oldpeak', 'MaxHR', 'Age', 'ExerciseAngina_Y', 'ST_Slope_Flat', 'ST_Slope_Up']]
    return scaler.transform(df_selected)


@app.post("/predict")
|
|
def predict(data: InferenceData):
|
|
try:
|
|
|
|
df = pd.DataFrame([data.model_dump()])
|
|
|
|
|
|
processed_data = preprocess_data(df)
|
|
|
|
|
|
prediction = model.predict(processed_data)
|
|
|
|
|
|
return {"prediction": int(prediction[0])}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
|
|
|
|
|
|
|
|
@app.post("/batch_predict")
|
|
def batch_predict(data: List[InferenceData]):
|
|
try:
|
|
|
|
df = pd.DataFrame([item.model_dump() for item in data])
|
|
|
|
|
|
processed_data = preprocess_data(df)
|
|
|
|
|
|
predictions = model.predict(processed_data)
|
|
|
|
|
|
results = [{"input": item.model_dump(), "prediction": int(pred)} for item, pred in zip(data, predictions)]
|
|
return {"predictions": results}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Error during batch prediction: {str(e)}") |
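

# A minimal sketch for running the service locally, assuming uvicorn is
# installed (an assumption; the project may serve the app differently, e.g.
# via `uvicorn <module>:app --reload` with the actual module name).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)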