zohaibterminator commited on
Commit
9c66090
·
verified ·
1 Parent(s): f2680e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import requests
3
  import pandas as pd
 
4
 
5
  def load_preprocessing_components():
6
  with open("encoder.pkl", "rb") as f:
@@ -9,6 +10,21 @@ def load_preprocessing_components():
9
  scaler = dill.load(f)
10
  return encoder, scaler
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def preprocess_data(df: pd.DataFrame) -> np.ndarray:
13
  # Encode categorical variables
14
  encoded = encoder.transform(df[encoder.feature_names_in_])
@@ -25,10 +41,11 @@ def preprocess_data(df: pd.DataFrame) -> np.ndarray:
25
 
26
  return df
27
 
28
- def predict(data: Dict()):
 
29
  try:
30
  # Convert input data to DataFrame
31
- df = pd.DataFrame([data.model_dump()])
32
 
33
  # Preprocess data
34
  processed_data = preprocess_data(df)
@@ -59,22 +76,20 @@ exercise_angina = st.selectbox("Exercise-Induced Angina", ["Y", "N"])
59
  oldpeak = st.number_input("Oldpeak", min_value=0.0, max_value=10.0, step=0.1)
60
  st_slope = st.selectbox("ST Slope", ["Up", "Flat", "Down"])
61
 
62
- # Button to submit the form
63
  if st.button("Predict"):
64
- # Prepare the data payload
65
- data = {
66
- "Age": age,
67
- "Sex": sex,
68
- "ChestPainType": chest_pain_type,
69
- "RestingBP": resting_bp,
70
- "Cholesterol": cholesterol,
71
- "FastingBS": fasting_bs,
72
- "RestingECG": resting_ecg,
73
- "MaxHR": max_hr,
74
- "ExerciseAngina": exercise_angina,
75
- "Oldpeak": oldpeak,
76
- "ST_Slope": st_slope
77
- }
78
 
79
  # Send a request to the FastAPI server
80
  response = predict(data)
@@ -87,7 +102,7 @@ if st.button("Predict"):
87
  else:
88
  st.error("Error: Unable to get prediction from API. Please try again later.")
89
 
90
- # Batch Prediction Section
91
  st.subheader("Batch Prediction")
92
  uploaded_file = st.file_uploader("Upload CSV for batch prediction", type="csv")
93
 
 
1
  import streamlit as st
2
  import requests
3
  import pandas as pd
4
+ import numpy as np
5
 
6
  def load_preprocessing_components():
7
  with open("encoder.pkl", "rb") as f:
 
10
  scaler = dill.load(f)
11
  return encoder, scaler
12
 
13
+
14
+ class InferenceData(BaseModel):
15
+ Age: float
16
+ Sex: str
17
+ ChestPainType: str
18
+ RestingBP: float
19
+ Cholesterol: float
20
+ FastingBS: int
21
+ RestingECG: str
22
+ MaxHR: float
23
+ ExerciseAngina: str
24
+ Oldpeak: float
25
+ ST_Slope: str
26
+
27
+
28
  def preprocess_data(df: pd.DataFrame) -> np.ndarray:
29
  # Encode categorical variables
30
  encoded = encoder.transform(df[encoder.feature_names_in_])
 
41
 
42
  return df
43
 
44
+
45
+ def predict(data: InferenceData):
46
  try:
47
  # Convert input data to DataFrame
48
+ df = pd.DataFrame(data.dict())
49
 
50
  # Preprocess data
51
  processed_data = preprocess_data(df)
 
76
  oldpeak = st.number_input("Oldpeak", min_value=0.0, max_value=10.0, step=0.1)
77
  st_slope = st.selectbox("ST Slope", ["Up", "Flat", "Down"])
78
 
 
79
  if st.button("Predict"):
80
+ data = InferenceData(
81
+ Age=age,
82
+ Sex=sex,
83
+ ChestPainType=chest_pain_type,
84
+ RestingBP=resting_bp,
85
+ Cholesterol=cholesterol,
86
+ FastingBS=fasting_bs,
87
+ RestingECG=resting_ecg,
88
+ MaxHR=max_hr,
89
+ ExerciseAngina=exercise_angina,
90
+ Oldpeak=oldpeak,
91
+ ST_Slope=st_slope
92
+ )
 
93
 
94
  # Send a request to the FastAPI server
95
  response = predict(data)
 
102
  else:
103
  st.error("Error: Unable to get prediction from API. Please try again later.")
104
 
105
+
106
  st.subheader("Batch Prediction")
107
  uploaded_file = st.file_uploader("Upload CSV for batch prediction", type="csv")
108