# Hugging Face Hub page residue (not code), preserved as a comment:
# uploaded by zohaibterminator — "Upload 13 files", commit 8c10e4d (verified),
# raw / history / blame view, 2.91 kB.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import pandas as pd
import numpy as np
from model_load_save import load_model
import dill
def load_preprocessing_components(encoder_path: str = "encoder.pkl",
                                  scaler_path: str = "scaler.pkl"):
    """Load the fitted categorical encoder and feature scaler from disk.

    Args:
        encoder_path: Path to the dill-serialized encoder artifact.
            Defaults to ``"encoder.pkl"`` (previous hard-coded behavior).
        scaler_path: Path to the dill-serialized scaler artifact.
            Defaults to ``"scaler.pkl"``.

    Returns:
        tuple: ``(encoder, scaler)`` exactly as deserialized by ``dill``.

    NOTE(review): ``dill.load`` executes arbitrary code on malicious files —
    only load artifacts produced by the trusted training pipeline.
    """
    with open(encoder_path, "rb") as f:
        encoder = dill.load(f)
    with open(scaler_path, "rb") as f:
        scaler = dill.load(f)
    return encoder, scaler
# FastAPI application instance; endpoints below are registered via decorators.
app = FastAPI()
# Load trained model
# (load_model is project-local; the model type is not visible here.)
model = load_model()
# Fitted preprocessing artifacts — must be the same objects fitted during
# training so inference-time features match the training distribution.
encoder, scaler = load_preprocessing_components()
# Define input schema
class InferenceData(BaseModel):
    """Request schema for one patient record (heart-disease features).

    Field names mirror the training dataset's column names; categorical
    string fields are encoded server-side by the fitted encoder in
    ``preprocess_data``.
    """

    Age: float
    Sex: str  # categorical — presumably "M"/"F"; confirm against training data
    ChestPainType: str  # categorical code — category set not visible here
    RestingBP: float  # resting blood pressure — units assumed mmHg, TODO confirm
    Cholesterol: float
    FastingBS: int  # presumably a 0/1 flag — verify against training schema
    RestingECG: str  # categorical code
    MaxHR: float  # maximum heart rate
    ExerciseAngina: str  # categorical — encoder emits "ExerciseAngina_Y"
    Oldpeak: float  # ST depression value
    ST_Slope: str  # categorical — encoder emits "ST_Slope_Flat"/"ST_Slope_Up"
# Health check endpoint
@app.get("/")
def read_root():
    """Liveness probe: confirm the API process is up and serving requests."""
    payload = {"message": "Inference API is up and running"}
    return payload
# Helper function for preprocessing
def preprocess_data(df: pd.DataFrame) -> np.ndarray:
    """Transform raw request rows into the scaled feature matrix the model expects.

    Uses the module-level fitted ``encoder`` and ``scaler``. The selected
    feature list and order must match what the model was trained on.

    Args:
        df: DataFrame with one row per record, columns matching
            ``InferenceData`` field names.

    Returns:
        np.ndarray: Scaled feature matrix, one row per input record.
    """
    # Encode the categorical columns the encoder was fitted on.
    encoded = encoder.transform(df[encoder.feature_names_in_])
    encoded_df = pd.DataFrame(
        encoded, columns=encoder.get_feature_names_out(), index=df.index
    )
    # Replace raw categorical columns with their encoded counterparts.
    df = pd.concat([df.drop(encoder.feature_names_in_, axis=1), encoded_df], axis=1)
    # Select exactly the trained-on features, in training order.
    # (Original built this via two concatenated sub-frames; a single
    # label-list selection is equivalent and avoids the extra copies.)
    df_selected = df[['Oldpeak', 'MaxHR', 'Age',
                      'ExerciseAngina_Y', 'ST_Slope_Flat', 'ST_Slope_Up']]
    # Scale to the training distribution.
    return scaler.transform(df_selected)
# Endpoint for single prediction
@app.post("/predict")
def predict(data: InferenceData):
    """Predict the class label for a single patient record.

    Returns ``{"prediction": <int>}`` on success; any failure is surfaced
    as an HTTP 500 with the underlying error message.
    """
    try:
        # Single-row frame from the validated request payload.
        frame = pd.DataFrame([data.model_dump()])
        features = preprocess_data(frame)
        # model.predict returns an array; the single row's label is index 0.
        label = model.predict(features)[0]
        return {"prediction": int(label)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
# Endpoint for batch prediction
@app.post("/batch_predict")
def batch_predict(data: List[InferenceData]):
    """Predict class labels for a batch of patient records.

    Args:
        data: List of validated records; may be empty.

    Returns:
        dict: ``{"predictions": [{"input": <record>, "prediction": <int>}, ...]}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # An empty batch would build a DataFrame with no columns and crash
    # preprocessing with a confusing 500 — answer it directly instead.
    if not data:
        return {"predictions": []}
    try:
        # Convert list of inputs to DataFrame, one row per record.
        df = pd.DataFrame([item.model_dump() for item in data])
        processed_data = preprocess_data(df)
        predictions = model.predict(processed_data)
        # Pair each input record with its predicted label.
        results = [
            {"input": item.model_dump(), "prediction": int(pred)}
            for item, pred in zip(data, predictions)
        ]
        return {"predictions": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during batch prediction: {str(e)}")