wjnwjn59 commited on
Commit
53227fd
·
1 Parent(s): 6fab39f

modify logic app

Browse files
Files changed (3) hide show
  1. app.py +190 -104
  2. src/heart_disease_core.py +92 -41
  3. vlai_template.py +147 -39
app.py CHANGED
@@ -2,14 +2,15 @@ import os
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
  import pandas as pd
 
5
 
6
  from src.heart_disease_core import (
7
- CLEVELAND_FEATURES_ORDER, TARGET_COL,
8
- load_cleveland_dataframe, fit_all_models, predict_all, example_patient
9
  )
10
 
11
- APP_PRIMARY = "#0F6CBD" # medical calm blue
12
- APP_ACCENT = "#C4314B" # medical alert red
13
  APP_BG = "#F7FAFC"
14
 
15
  STATE = {
@@ -20,21 +21,35 @@ STATE = {
20
 
21
  DATA_PATH = "data/cleveland.csv"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # -----------------------------
25
- # Startup / init
26
- # -----------------------------
27
  def init_page():
28
- """
29
- Load dataset from disk, fit models once, and return preview + metrics.
30
- Returns plain values (no .update), to maximize Gradio compatibility.
31
- """
32
  if not os.path.exists(DATA_PATH):
33
  msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
34
  return msg, pd.DataFrame(), pd.DataFrame()
35
 
36
- raw = pd.read_csv(DATA_PATH)
37
- df = load_cleveland_dataframe(uploaded_df=raw) # cleans, binarizes target
38
 
39
  models, metrics = fit_all_models(df)
40
  STATE["df"] = df
@@ -46,42 +61,60 @@ def init_page():
46
  return msg, head, metrics
47
 
48
 
49
- # -----------------------------
50
- # Helpers
51
- # -----------------------------
52
  def fill_example(idx_text: str):
53
- idx = {
54
- "Example 1 (likely negative)": 0,
55
- "Example 2 (borderline)": 1,
56
- "Example 3 (likely positive)": 2
57
- }[idx_text]
 
 
58
  ex = example_patient(idx)
59
- # Return in the strict feature order so Gradio can assign to outputs 1:1
60
  return [ex[c] for c in CLEVELAND_FEATURES_ORDER]
61
 
62
 
63
  def _bar_for_models(results: dict):
64
  names = list(results.keys())
65
- probs = [results[n]["prob_1"] for n in names]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  fig = go.Figure()
68
- fig.add_bar(x=names, y=probs, text=[f"{p:.2f}" for p in probs], textposition="auto")
69
  fig.update_layout(
70
- title="Model Confidence (P[Heart Disease = 1])",
71
- yaxis_title="Probability",
72
  xaxis_title="Model",
73
  yaxis=dict(range=[0, 1]),
74
  plot_bgcolor="white",
 
 
75
  height=420,
76
  margin=dict(l=30, r=20, t=60, b=40)
77
  )
78
- # Emphasize ensemble bar
79
- colors = ["#9BB8D3"] * len(names)
80
- if "Ensemble (Soft Voting)" in names:
81
- colors[names.index("Ensemble (Soft Voting)")] = APP_ACCENT
82
- elif len(colors) > 0:
83
- colors[-1] = APP_ACCENT
84
- fig.data[0].marker.color = colors
85
  return fig
86
 
87
 
@@ -93,17 +126,43 @@ def run_predict(*vals):
93
  results = predict_all(STATE["models"], input_dict)
94
 
95
  final = results["Ensemble (Soft Voting)"]
96
- title_md = (
97
- f"### 🫀 Cleveland Heart Disease Diagnosis\n"
98
- f"**Ensemble Prediction**: **{'Positive' if final['label'] == 1 else 'Negative'}** \n"
99
- f"**Confidence (P=1)**: `{final['prob_1']:.3f}`"
100
- )
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  rows = []
103
  for name, r in results.items():
 
104
  rows.append({
105
  "Model": name,
106
- "Predicted label": "Positive" if r["label"] == 1 else "Negative",
 
107
  "P(No disease)": round(r["prob_0"], 3),
108
  "P(Heart disease)": round(r["prob_1"], 3),
109
  })
@@ -111,80 +170,107 @@ def run_predict(*vals):
111
 
112
  fig = _bar_for_models(results)
113
 
114
- # Return plain values for Markdown, Plot, Markdown, DataFrame
115
- return title_md, fig, "**Per-Model Predictions**", table_df
116
 
117
 
118
- # -----------------------------
119
- # UI (no gr.Box to avoid older-Gradio issues)
120
- # -----------------------------
121
- with gr.Blocks(theme="soft", css=f"""
122
- :root {{
123
- --primary-600: {APP_PRIMARY};
124
- }}
125
- .gradio-container {{ background: {APP_BG}; }}
126
- h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
127
- """) as demo:
128
- gr.Markdown("# 🫀 Cleveland Heart Disease Diagnosis (Ensemble Demo)")
129
 
130
- with gr.Row(equal_height=False):
131
  # LEFT: data preview + inputs
132
  with gr.Column(scale=45):
133
- gr.Markdown("### 📁 Dataset & Model Status")
134
- status_md = gr.Markdown("Loading dataset and training models...")
135
- preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
136
- metrics_df = gr.DataFrame(label="Validation ROC-AUC (80/20 split)", interactive=False)
137
-
138
- gr.Markdown("### ✍️ Enter Patient Features")
139
- with gr.Row():
140
- age = gr.Number(label="age (years)", value=58)
141
- sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0, 1], value=1)
142
- cp = gr.Dropdown(label="cp (chest pain type 0..3)", choices=[0, 1, 2, 3], value=2)
143
- trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
144
-
145
- with gr.Row():
146
- chol = gr.Number(label="chol (serum cholesterol mg/dl)", value=250)
147
- fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0, 1], value=0)
148
- restecg = gr.Dropdown(label="restecg (0..2)", choices=[0, 1, 2], value=1)
149
- thalach = gr.Number(label="thalach (max heart rate)", value=150)
150
-
151
- with gr.Row():
152
- exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0, 1], value=0)
153
- oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
154
- slope = gr.Dropdown(label="slope (0..2)", choices=[0, 1, 2], value=1)
155
- ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0, 1, 2, 3], value=0)
156
-
157
- thal = gr.Dropdown(label="thal (1=normal, 2=fixed, 3=reversible)", choices=[1, 2, 3], value=2)
158
-
159
- with gr.Row():
160
- ex_selector = gr.Dropdown(
161
- label="Fill Example",
162
- choices=["Example 1 (likely negative)", "Example 2 (borderline)", "Example 3 (likely positive)"],
163
- value="Example 2 (borderline)"
164
- )
165
- fill_btn = gr.Button("🧪 Use Example")
166
- predict_btn = gr.Button("🔍 Predict", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # RIGHT: outputs
169
  with gr.Column(scale=55):
170
- gr.Markdown("### 📈 Predictions")
171
- title_out = gr.Markdown("Ensemble Prediction will appear here.")
172
- bar_out = gr.Plot(label="Model Confidence")
173
- sub_md = gr.Markdown(visible=False)
174
- table_out = gr.DataFrame(visible=False)
175
-
176
- with gr.Accordion("ℹ️ Notes", open=False):
177
- gr.Markdown(
178
- "- Models are trained once at launch on `data/cleveland.csv` (80/20 split).\n"
179
- "- `target` is binarized automatically (0 = no disease, >0 = disease).\n"
180
- "- Ensemble uses **soft voting** over Decision Tree, k-NN, and Naive Bayes.\n"
181
- "- Educational demo only; **not medical advice**."
182
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # Bind events
185
  demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
186
 
187
- fill_btn.click(
 
188
  fn=fill_example,
189
  inputs=[ex_selector],
190
  outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
@@ -193,8 +279,8 @@ h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
193
  predict_btn.click(
194
  fn=run_predict,
195
  inputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal],
196
- outputs=[title_out, bar_out, sub_md, table_out]
197
  )
198
 
199
  if __name__ == "__main__":
200
- demo.launch()
 
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
  import pandas as pd
5
+ import vlai_template
6
 
7
  from src.heart_disease_core import (
8
+ CLEVELAND_FEATURES_ORDER,
9
+ load_cleveland_dataframe, fit_all_models, predict_all, example_patient, get_example_labels
10
  )
11
 
12
+ APP_PRIMARY = vlai_template.PRIMARY_COLOR
13
+ APP_ACCENT = vlai_template.ACCENT_COLOR
14
  APP_BG = "#F7FAFC"
15
 
16
  STATE = {
 
21
 
22
  DATA_PATH = "data/cleveland.csv"
23
 
24
+ vlai_template.set_meta(
25
+ project_name="Heart Disease Diagnosis Project",
26
+ year="2025",
27
+ module="03",
28
+ description="Predict heart disease risk from patient data with ML models trained on the Cleveland dataset.",
29
+ meta_items=[
30
+ ("Dataset", "Cleveland Heart Disease"),
31
+ ("Models", "Decision Tree, k-NN, Naive Bayes"),
32
+ ("Ensemble", "Soft Voting"),
33
+ ],
34
+ )
35
+
36
+ force_light_theme_js = """
37
+ () => {
38
+ const params = new URLSearchParams(window.location.search);
39
+ if (!params.has('__theme')) {
40
+ params.set('__theme', 'light');
41
+ window.location.search = params.toString();
42
+ }
43
+ }
44
+ """
45
 
 
 
 
46
  def init_page():
47
+ """Load dataset, train models, and return status, preview, metrics."""
 
 
 
48
  if not os.path.exists(DATA_PATH):
49
  msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
50
  return msg, pd.DataFrame(), pd.DataFrame()
51
 
52
+ df = load_cleveland_dataframe(file_path=DATA_PATH)
 
53
 
54
  models, metrics = fit_all_models(df)
55
  STATE["df"] = df
 
61
  return msg, head, metrics
62
 
63
 
 
 
 
64
  def fill_example(idx_text: str):
65
+ import re
66
+ match = re.search(r'Example (\d+)', idx_text)
67
+ if match:
68
+ idx = int(match.group(1)) - 1
69
+ else:
70
+ idx = 1
71
+
72
  ex = example_patient(idx)
 
73
  return [ex[c] for c in CLEVELAND_FEATURES_ORDER]
74
 
75
 
76
  def _bar_for_models(results: dict):
77
  names = list(results.keys())
78
+ confidences = []
79
+ predictions_text = []
80
+ bar_colors = []
81
+ line_colors = []
82
+ line_widths = []
83
+
84
+ for n in names:
85
+ r = results[n]
86
+ if r["label"] == 1:
87
+ confidences.append(r["prob_1"])
88
+ predictions_text.append("🫀 Heart Disease")
89
+ bar_colors.append("#C4314B")
90
+ else:
91
+ confidences.append(r["prob_0"])
92
+ predictions_text.append("✅ No Heart Disease")
93
+ bar_colors.append("#2E7D32")
94
+ line_colors.append("rgba(0,0,0,0.15)")
95
+ line_widths.append(1.0)
96
+
97
+ if "Ensemble (Soft Voting)" in names:
98
+ idx = names.index("Ensemble (Soft Voting)")
99
+ line_colors[idx] = "#000000"
100
+ line_widths[idx] = 2.5
101
 
102
  fig = go.Figure()
103
+ fig.add_bar(x=names, y=confidences, text=predictions_text, textposition="auto")
104
  fig.update_layout(
105
+ title="Model Predictions",
106
+ yaxis_title="Prediction Confidence",
107
  xaxis_title="Model",
108
  yaxis=dict(range=[0, 1]),
109
  plot_bgcolor="white",
110
+ paper_bgcolor="white",
111
+ font=dict(color="black", size=12),
112
  height=420,
113
  margin=dict(l=30, r=20, t=60, b=40)
114
  )
115
+ fig.data[0].marker.color = bar_colors
116
+ fig.data[0].marker.line.color = line_colors
117
+ fig.data[0].marker.line.width = line_widths
 
 
 
 
118
  return fig
119
 
120
 
 
126
  results = predict_all(STATE["models"], input_dict)
127
 
128
  final = results["Ensemble (Soft Voting)"]
129
+ ensemble_color = "#C4314B" if final["label"] == 1 else "#2E7D32"
130
+ ensemble_prediction = "🫀 **Heart Disease Detected**" if final["label"] == 1 else "✅ **No Heart Disease**"
131
+
132
+ ensemble_md = f"""
133
+ <div style=\"border: 3px solid {ensemble_color}; border-radius: 10px; padding: 20px; margin: 15px 0; background: white;\">
134
+ <h3 style=\"margin: 0 0 15px 0; color: {ensemble_color};\">🎯 Ensemble Prediction (Final Result)</h3>
135
+ <p style=\"margin: 10px 0; font-size: 18px; color: black;\"><strong>{ensemble_prediction}</strong></p>
136
+ <p style=\"margin: 5px 0; font-size: 16px; color: black;\"><strong>Confidence:</strong> {final['prob_1']:.1%}</p>
137
+ </div>
138
+ """
139
+
140
+ model_predictions = []
141
+ for name, r in results.items():
142
+ prediction_text = "🫀 **Heart Disease Detected**" if r["label"] == 1 else "✅ **No Heart Disease**"
143
+ confidence = r["prob_1"] if r["label"] == 1 else r["prob_0"]
144
+ color = "#C4314B" if r["label"] == 1 else "#2E7D32"
145
+
146
+ model_predictions.append(f"""
147
+ <div style=\"border: 2px solid {color}; border-radius: 8px; padding: 15px; margin: 10px 0; background: white;\">
148
+ <h4 style=\"margin: 0 0 10px 0; color: {color};\">{name}</h4>
149
+ <p style=\"margin: 5px 0; font-size: 16px; color: black;\"><strong>Prediction:</strong> {prediction_text}</p>
150
+ <p style=\"margin: 5px 0; font-size: 14px; color: black;\"><strong>Confidence:</strong> {confidence:.1%}</p>
151
+ <p style=\"margin: 5px 0; font-size: 12px; color: #666;\">
152
+ P(No disease): {r['prob_0']:.3f} | P(Heart disease): {r['prob_1']:.3f}
153
+ </p>
154
+ </div>
155
+ """)
156
+
157
+ all_predictions = "\n".join(model_predictions)
158
+
159
  rows = []
160
  for name, r in results.items():
161
+ confidence = r["prob_1"] if r["label"] == 1 else r["prob_0"]
162
  rows.append({
163
  "Model": name,
164
+ "Prediction": "Heart Disease" if r["label"] == 1 else "No Heart Disease",
165
+ "Confidence": f"{confidence:.1%}",
166
  "P(No disease)": round(r["prob_0"], 3),
167
  "P(Heart disease)": round(r["prob_1"], 3),
168
  })
 
170
 
171
  fig = _bar_for_models(results)
172
 
173
+ return fig, "\n".join(model_predictions), table_df
 
174
 
175
 
176
+ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
177
+ vlai_template.create_header()
178
+ gr.HTML(vlai_template.render_info_card(icon="🫀", title="About this demo"))
179
+ gr.HTML(vlai_template.render_disclaimer(
180
+ text=(
181
+ "This interactive heart disease prediction demo is provided strictly for educational purposes. "
182
+ "It is not intended for clinical use and must not be relied upon for medical advice, diagnosis, "
183
+ "treatment, or decision-making. Always consult a qualified healthcare professional."
184
+ )
185
+ ))
186
+ gr.Markdown("### 🫀 **How to Use**: Enter patient features → Run prediction → View ensemble results!")
187
 
188
+ with gr.Row(equal_height=False, variant="panel"):
189
  # LEFT: data preview + inputs
190
  with gr.Column(scale=45):
191
+ with gr.Accordion("📁 Dataset & Model Status", open=True):
192
+ status_md = gr.Markdown("Loading dataset and training models...")
193
+ preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
194
+ metrics_df = gr.DataFrame(label="Validation Metrics (80/20 split)", interactive=False)
195
+
196
+ with gr.Accordion("✍️ Enter Patient Features", open=True):
197
+ with gr.Row():
198
+ age = gr.Number(label="age (years)", value=58)
199
+ sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0, 1], value=1)
200
+ cp = gr.Dropdown(label="cp (chest pain type 1..4)", choices=[1, 2, 3, 4], value=2)
201
+ trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
202
+
203
+ with gr.Row():
204
+ chol = gr.Number(label="chol (serum cholesterol mg/dl)", value=250)
205
+ fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0, 1], value=0)
206
+ restecg = gr.Dropdown(label="restecg (0..2)", choices=[0, 1, 2], value=1)
207
+ thalach = gr.Number(label="thalach (max heart rate)", value=150)
208
+
209
+ with gr.Row():
210
+ exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0, 1], value=0)
211
+ oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
212
+ slope = gr.Dropdown(label="slope (1..3)", choices=[1, 2, 3], value=1)
213
+ ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0, 1, 2, 3], value=0)
214
+
215
+ thal = gr.Dropdown(label="thal (3=normal, 6=fixed, 7=reversible)", choices=[3, 6, 7], value=3)
216
+
217
+ with gr.Row():
218
+ # Get actual labels from the dataset - only 2 examples
219
+ try:
220
+ labels = get_example_labels()
221
+ choices = []
222
+ # Only use first two examples: one no disease, one disease
223
+ for i in range(min(2, len(labels))):
224
+ label_text = "No Heart Disease" if labels[i] == 0 else "Heart Disease"
225
+ choices.append(f"Example {i+1} ({label_text})")
226
+ default_choice = choices[0] if choices else "Example 1"
227
+ except:
228
+ choices = ["Example 1 (No Heart Disease)", "Example 2 (Heart Disease)"]
229
+ default_choice = "Example 1 (No Heart Disease)"
230
+
231
+ ex_selector = gr.Dropdown(
232
+ label="Select Example Patient",
233
+ choices=choices,
234
+ value=default_choice
235
+ )
236
+ predict_btn = gr.Button("🔍 Predict", variant="primary")
237
 
238
  # RIGHT: outputs
239
  with gr.Column(scale=55):
240
+ gr.Markdown("### 📈 Model Predictions")
241
+ bar_out = gr.Plot(label="Model Predictions Overview")
242
+ sub_md = gr.Markdown("**Individual Model Results**")
243
+ table_out = gr.DataFrame(label="All Model Predictions", interactive=False)
244
+
245
+ gr.Markdown("""
246
+ ## 📋 **Notes**
247
+
248
+ - **Models are trained once at launch** on `data/cleveland.csv` (80/20 split).
249
+ - **Target is binarized automatically** (0 = no disease, >0 = disease).
250
+ - **Ensemble uses soft voting** over Decision Tree, k-NN, and Naive Bayes.
251
+ - **Feature descriptions**:
252
+ - `age`: Patient age in years
253
+ - `sex`: Gender (0=female, 1=male)
254
+ - `cp`: Chest pain type (1-4)
255
+ - `trestbps`: Resting blood pressure (mmHg)
256
+ - `chol`: Serum cholesterol (mg/dl)
257
+ - `fbs`: Fasting blood sugar >120 mg/dl (1=true, 0=false)
258
+ - `restecg`: Resting ECG results (0-2)
259
+ - `thalach`: Maximum heart rate achieved
260
+ - `exang`: Exercise induced angina (1=yes, 0=no)
261
+ - `oldpeak`: ST depression induced by exercise
262
+ - `slope`: Slope of peak exercise ST segment (1-3)
263
+ - `ca`: Number of major vessels colored by fluoroscopy (0-3)
264
+ - `thal`: Thalassemia (3=normal, 6=fixed defect, 7=reversible defect)
265
+ """)
266
+
267
+ vlai_template.create_footer()
268
 
269
  # Bind events
270
  demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
271
 
272
+ # Auto-fill when example is selected
273
+ ex_selector.change(
274
  fn=fill_example,
275
  inputs=[ex_selector],
276
  outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
 
279
  predict_btn.click(
280
  fn=run_predict,
281
  inputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal],
282
+ outputs=[bar_out, sub_md, table_out]
283
  )
284
 
285
  if __name__ == "__main__":
286
+ demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])
src/heart_disease_core.py CHANGED
@@ -1,4 +1,3 @@
1
- # src/heart_disease_core.py
2
  import os
3
  import numpy as np
4
  import pandas as pd
@@ -9,7 +8,7 @@ from sklearn.preprocessing import OneHotEncoder
9
  from sklearn.compose import ColumnTransformer
10
  from sklearn.pipeline import Pipeline
11
  from sklearn.impute import SimpleImputer
12
- from sklearn.metrics import roc_auc_score
13
  from sklearn.tree import DecisionTreeClassifier
14
  from sklearn.neighbors import KNeighborsClassifier
15
  from sklearn.naive_bayes import GaussianNB
@@ -20,48 +19,41 @@ CLEVELAND_FEATURES_ORDER: List[str] = [
20
  "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
21
  "thalach", "exang", "oldpeak", "slope", "ca", "thal"
22
  ]
23
- TARGET_COL = "target" # 0: no disease, 1: disease (we binarize if needed)
24
 
25
  CATEGORICAL_CHOICES = {
26
- "sex": [0, 1], # 0: female, 1: male
27
- "cp": [0, 1, 2, 3], # chest pain type
28
- "fbs": [0, 1], # fasting blood sugar > 120 mg/dl (1 true, 0 false)
29
- "restecg": [0, 1, 2], # resting ECG results
30
- "exang": [0, 1], # exercise-induced angina
31
- "slope": [0, 1, 2], # slope of ST
32
- "ca": [0, 1, 2, 3], # number of major vessels (0-3) colored by fluoroscopy
33
- "thal": [1, 2, 3], # 1: normal, 2: fixed defect, 3: reversible defect (commonly 3/6/7 variants exist; we standardize)
34
  }
35
 
36
  NUMERIC_COLS = ["age", "trestbps", "chol", "thalach", "oldpeak"]
37
  CATEGORICAL_COLS = ["sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal"]
38
 
39
  def _coerce_and_clean(df: pd.DataFrame) -> pd.DataFrame:
40
- """Clean '?' and cast numeric; keep only known columns if present."""
41
  df = df.copy()
42
- # Standardize columns if they are present with any case
43
  colmap = {c.lower(): c for c in df.columns}
44
  for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
45
  if col not in df.columns and col in colmap:
46
- df[col] = df.pop(colmap[col]) # normalize name
47
 
48
- # Replace '?' with NaN and cast
49
  for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
50
  if col in df.columns:
51
  df[col] = pd.to_numeric(df[col].replace("?", np.nan), errors="coerce")
52
 
53
- # Binarize target if it appears as 0..4 (UCI often uses 0 vs 1..4 disease)
54
  if TARGET_COL in df.columns:
55
  df[TARGET_COL] = (df[TARGET_COL] > 0).astype(int)
56
 
57
  return df
58
 
59
  def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
60
- """
61
- Load the Cleveland Heart Disease dataset.
62
- Priority: uploaded_df > file_path > raise.
63
- Expect columns CLEVELAND_FEATURES_ORDER + TARGET_COL.
64
- """
65
  if uploaded_df is not None:
66
  df = _coerce_and_clean(uploaded_df)
67
  missing = [c for c in CLEVELAND_FEATURES_ORDER + [TARGET_COL] if c not in df.columns]
@@ -71,7 +63,19 @@ def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optio
71
 
72
  if file_path is not None and os.path.exists(file_path):
73
  if file_path.endswith(".csv"):
74
- df = pd.read_csv(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
75
  else:
76
  df = pd.read_excel(file_path)
77
  df = _coerce_and_clean(df)
@@ -171,14 +175,29 @@ def fit_all_models(df: pd.DataFrame, test_size: float = 0.2, random_state: int =
171
 
172
  for name, pipe in models.items():
173
  pipe.fit(X_tr, y_tr)
 
 
174
  if hasattr(pipe, "predict_proba"):
175
  proba = pipe.predict_proba(X_te)[:, 1]
176
  auc = roc_auc_score(y_te, proba)
177
  else:
178
- # Fallback if any (unlikely here)
179
- pred = pipe.predict(X_te)
180
- auc = roc_auc_score(y_te, pred)
181
- metrics.append({"model": name, "ROC-AUC": round(float(auc), 4)})
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  metrics_df = pd.DataFrame(metrics).sort_values("ROC-AUC", ascending=False, ignore_index=True)
184
  return models, metrics_df
@@ -210,19 +229,51 @@ def predict_all(models: Dict[str, Pipeline], input_dict: Dict[str, float]) -> Di
210
 
211
  def example_patient(index: int = 0) -> Dict[str, float]:
212
  """
213
- A few realistic examples pulled from common Cleveland-like ranges.
214
- You can add more patterns for quick testing.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  """
216
- examples = [
217
- # Likely negative (no disease)
218
- dict(age=45, sex=0, cp=0, trestbps=120, chol=230, fbs=0, restecg=1,
219
- thalach=168, exang=0, oldpeak=0.0, slope=2, ca=0, thal=2),
220
- # Borderline
221
- dict(age=58, sex=1, cp=2, trestbps=138, chol=250, fbs=0, restecg=0,
222
- thalach=150, exang=0, oldpeak=1.0, slope=1, ca=1, thal=2),
223
- # Likely positive (disease)
224
- dict(age=63, sex=1, cp=3, trestbps=145, chol=320, fbs=1, restecg=2,
225
- thalach=130, exang=1, oldpeak=2.8, slope=0, ca=2, thal=3),
226
- ]
227
- index = max(0, min(index, len(examples) - 1))
228
- return examples[index]
 
 
1
  import os
2
  import numpy as np
3
  import pandas as pd
 
8
  from sklearn.compose import ColumnTransformer
9
  from sklearn.pipeline import Pipeline
10
  from sklearn.impute import SimpleImputer
11
+ from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
12
  from sklearn.tree import DecisionTreeClassifier
13
  from sklearn.neighbors import KNeighborsClassifier
14
  from sklearn.naive_bayes import GaussianNB
 
19
  "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
20
  "thalach", "exang", "oldpeak", "slope", "ca", "thal"
21
  ]
22
+ TARGET_COL = "target"
23
 
24
  CATEGORICAL_CHOICES = {
25
+ "sex": [0, 1],
26
+ "cp": [0, 1, 2, 3],
27
+ "fbs": [0, 1],
28
+ "restecg": [0, 1, 2],
29
+ "exang": [0, 1],
30
+ "slope": [0, 1, 2],
31
+ "ca": [0, 1, 2, 3],
32
+ "thal": [1, 2, 3],
33
  }
34
 
35
  NUMERIC_COLS = ["age", "trestbps", "chol", "thalach", "oldpeak"]
36
  CATEGORICAL_COLS = ["sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal"]
37
 
38
  def _coerce_and_clean(df: pd.DataFrame) -> pd.DataFrame:
39
+ """Clean '?', cast numerics, normalize column names, and binarize target."""
40
  df = df.copy()
 
41
  colmap = {c.lower(): c for c in df.columns}
42
  for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
43
  if col not in df.columns and col in colmap:
44
+ df[col] = df.pop(colmap[col])
45
 
 
46
  for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
47
  if col in df.columns:
48
  df[col] = pd.to_numeric(df[col].replace("?", np.nan), errors="coerce")
49
 
 
50
  if TARGET_COL in df.columns:
51
  df[TARGET_COL] = (df[TARGET_COL] > 0).astype(int)
52
 
53
  return df
54
 
55
  def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
56
+ """Load Cleveland dataset from upload or file path and ensure schema."""
 
 
 
 
57
  if uploaded_df is not None:
58
  df = _coerce_and_clean(uploaded_df)
59
  missing = [c for c in CLEVELAND_FEATURES_ORDER + [TARGET_COL] if c not in df.columns]
 
63
 
64
  if file_path is not None and os.path.exists(file_path):
65
  if file_path.endswith(".csv"):
66
+ # Try reading with headers first; fall back to no header
67
+ try:
68
+ df = pd.read_csv(file_path)
69
+ if len(df.columns) == len(CLEVELAND_FEATURES_ORDER) + 1: # +1 for target
70
+ first_row_numeric = all(pd.to_numeric(df.iloc[0], errors='coerce').notna())
71
+ if first_row_numeric:
72
+ # Re-read without headers and assign names
73
+ df = pd.read_csv(file_path, header=None)
74
+ df.columns = CLEVELAND_FEATURES_ORDER + [TARGET_COL]
75
+ except:
76
+ # Fallback: read without headers
77
+ df = pd.read_csv(file_path, header=None)
78
+ df.columns = CLEVELAND_FEATURES_ORDER + [TARGET_COL]
79
  else:
80
  df = pd.read_excel(file_path)
81
  df = _coerce_and_clean(df)
 
175
 
176
  for name, pipe in models.items():
177
  pipe.fit(X_tr, y_tr)
178
+ # Predictions and probabilities
179
+ y_pred = pipe.predict(X_te)
180
  if hasattr(pipe, "predict_proba"):
181
  proba = pipe.predict_proba(X_te)[:, 1]
182
  auc = roc_auc_score(y_te, proba)
183
  else:
184
+ # Fallback if probabilities are not available
185
+ proba = None
186
+ auc = roc_auc_score(y_te, y_pred)
187
+
188
+ acc = accuracy_score(y_te, y_pred)
189
+ prec = precision_score(y_te, y_pred, zero_division=0)
190
+ rec = recall_score(y_te, y_pred, zero_division=0)
191
+ f1 = f1_score(y_te, y_pred, zero_division=0)
192
+
193
+ metrics.append({
194
+ "model": name,
195
+ "ROC-AUC": round(float(auc), 4),
196
+ "Accuracy": round(float(acc), 4),
197
+ "Precision": round(float(prec), 4),
198
+ "Recall": round(float(rec), 4),
199
+ "F1": round(float(f1), 4),
200
+ })
201
 
202
  metrics_df = pd.DataFrame(metrics).sort_values("ROC-AUC", ascending=False, ignore_index=True)
203
  return models, metrics_df
 
229
 
230
  def example_patient(index: int = 0) -> Dict[str, float]:
231
  """
232
+ Get example patients with specific features provided by user.
233
+ """
234
+ # Example 1: No heart disease (37,1,3,130,250,0,0,187,0,3.5,3,0,3,0)
235
+ # Example 2: Heart disease (56,1,3,130,256,1,2,142,1,0.6,2,1,6,2)
236
+
237
+ if index == 0:
238
+ # No heart disease example
239
+ return {
240
+ "age": 37.0,
241
+ "sex": 1.0,
242
+ "cp": 3.0,
243
+ "trestbps": 130.0,
244
+ "chol": 250.0,
245
+ "fbs": 0.0,
246
+ "restecg": 0.0,
247
+ "thalach": 187.0,
248
+ "exang": 0.0,
249
+ "oldpeak": 3.5,
250
+ "slope": 3.0,
251
+ "ca": 0.0,
252
+ "thal": 3.0
253
+ }
254
+ else:
255
+ # Heart disease example
256
+ return {
257
+ "age": 56.0,
258
+ "sex": 1.0,
259
+ "cp": 3.0,
260
+ "trestbps": 130.0,
261
+ "chol": 256.0,
262
+ "fbs": 1.0,
263
+ "restecg": 2.0,
264
+ "thalach": 142.0,
265
+ "exang": 1.0,
266
+ "oldpeak": 0.6,
267
+ "slope": 2.0,
268
+ "ca": 1.0,
269
+ "thal": 6.0
270
+ }
271
+
272
+ def get_example_labels() -> List[int]:
273
+ """
274
+ Get the labels for the example patients to display in the UI.
275
+ Returns list of labels for the specific examples provided.
276
  """
277
+ # Example 1: No heart disease (target = 0)
278
+ # Example 2: Heart disease (target = 2, binarized to 1)
279
+ return [0, 1] # First example: no disease, second example: heart disease
 
 
 
 
 
 
 
 
 
 
vlai_template.py CHANGED
@@ -1,11 +1,73 @@
1
  import os, base64
2
  import gradio as gr
3
 
4
-
5
- PROJECT_NAME = "Decision Tree Demo"
 
 
 
 
 
 
 
 
 
6
  AIO_YEAR = "2025"
7
- AIO_MODULE = "03"
8
- # END
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def image_to_base64(image_path: str):
@@ -28,7 +90,7 @@ def create_header():
28
  gr.HTML(f"""
29
  <div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
30
  <div>
31
- <h1 style="margin-bottom:0; color: #2E7D32; font-size: 2.5em; font-weight: bold;"> {PROJECT_NAME} </h1>
32
  <h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
33
  </div>
34
  </div>
@@ -53,90 +115,136 @@ def create_footer():
53
  """
54
  return gr.HTML(footer_html)
55
 
56
- custom_css = """
 
 
57
 
58
- .gradio-container {
59
  min-height: 100vh !important;
60
  width: 100vw !important;
61
  margin: 0 !important;
62
  padding: 0px !important;
63
- background: linear-gradient(135deg, #E8F5E8 0%, #D4E6D4 50%, #A8D8A8 100%);
64
  background-size: 600% 600%;
65
  animation: gradientBG 7s ease infinite;
66
- }
 
 
 
 
 
 
 
67
 
68
- @keyframes gradientBG {
69
- 0% {background-position: 0% 50%;}
70
- 50% {background-position: 100% 50%;}
71
- 100% {background-position: 0% 50%;}
72
- }
73
 
74
  /* Minimize spacing and padding */
75
- .content-wrap {
76
  padding: 2px !important;
77
  margin: 0 !important;
78
- }
79
 
80
  /* Reduce component spacing */
81
- .gr-row {
82
  gap: 5px !important;
83
  margin: 2px 0 !important;
84
- }
85
 
86
- .gr-column {
87
  gap: 4px !important;
88
  padding: 4px !important;
89
- }
90
 
91
  /* Accordion optimization */
92
- .gr-accordion {
93
  margin: 4px 0 !important;
94
- }
95
 
96
- .gr-accordion .gr-accordion-content {
97
  padding: 2px !important;
98
- }
99
 
100
  /* Form elements spacing */
101
- .gr-form {
102
  gap: 2px !important;
103
- }
104
 
105
  /* Button styling */
106
- .gr-button {
107
  margin: 2px 0 !important;
108
- }
109
 
110
  /* DataFrame optimization */
111
- .gr-dataframe {
112
  margin: 4px 0 !important;
113
- }
114
 
115
  /* Remove horizontal scroll from data preview */
116
- .gr-dataframe .wrap {
117
  overflow-x: auto !important;
118
  max-width: 100% !important;
119
- }
120
 
121
  /* Plot optimization */
122
- .gr-plot {
123
  margin: 4px 0 !important;
124
- }
125
 
126
  /* Reduce markdown margins */
127
- .gr-markdown {
128
  margin: 2px 0 !important;
129
- }
130
 
131
  /* Footer positioning */
132
- .sticky-footer {
133
  position: fixed;
134
  bottom: 0px;
135
  left: 0;
136
  width: 100%;
137
- background: #E8F5E8;
138
  padding: 6px !important;
139
  box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
140
  z-index: 1000;
141
- }
142
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os, base64
2
  import gradio as gr
3
 
4
+ # Theming (can be overridden by the host app)
5
+ PRIMARY_COLOR = "#0F6CBD" # medical calm blue
6
+ ACCENT_COLOR = "#C4314B" # medical alert red
7
+ SUCCESS_COLOR = "#2E7D32" # positive/ok
8
+ BG1 = "#F0F7FF"
9
+ BG2 = "#E8F0FA"
10
+ BG3 = "#DDE7F8"
11
+ FONT_FAMILY = "'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, 'Noto Sans', 'Liberation Sans', sans-serif"
12
+
13
+ # App metadata (overridable)
14
+ PROJECT_NAME = "Demo Project"
15
  AIO_YEAR = "2025"
16
+ AIO_MODULE = "00"
17
+ PROJECT_DESCRIPTION = ""
18
+ META_INFO = [] # list of (label, value)
19
+
20
+ def set_colors(primary: str = None, accent: str = None, bg1: str = None, bg2: str = None, bg3: str = None):
21
+ """Allow host app to set theme colors dynamically."""
22
+ global PRIMARY_COLOR, ACCENT_COLOR, BG1, BG2, BG3, custom_css
23
+ if primary:
24
+ PRIMARY_COLOR = primary
25
+ if accent:
26
+ ACCENT_COLOR = accent
27
+ if bg1:
28
+ BG1 = bg1
29
+ if bg2:
30
+ BG2 = bg2
31
+ if bg3:
32
+ BG3 = bg3
33
+ # Rebuild CSS with new colors
34
+ custom_css = _build_custom_css()
35
+
36
+ def set_font(font_family: str):
37
+ """Allow host app to set a custom font stack (e.g., 'Inter', system fallbacks)."""
38
+ global FONT_FAMILY, custom_css
39
+ if font_family and isinstance(font_family, str):
40
+ FONT_FAMILY = font_family
41
+ custom_css = _build_custom_css()
42
+
43
+ def set_meta(project_name: str = None, year: str = None, module: str = None, description: str = None, meta_items: list = None):
44
+ """Set project metadata used across the header and info sections."""
45
+ global PROJECT_NAME, AIO_YEAR, AIO_MODULE, PROJECT_DESCRIPTION, META_INFO
46
+ if project_name is not None:
47
+ PROJECT_NAME = project_name
48
+ if year is not None:
49
+ AIO_YEAR = year
50
+ if module is not None:
51
+ AIO_MODULE = module
52
+ if description is not None:
53
+ PROJECT_DESCRIPTION = description
54
+ if meta_items is not None:
55
+ META_INFO = meta_items
56
+
57
+ def configure(project_name: str = None, year: str = None, module: str = None, description: str = None,
58
+ colors: dict = None, font_family: str = None, meta_items: list = None):
59
+ """One-call configuration for meta, theme, and font."""
60
+ if colors:
61
+ set_colors(
62
+ primary=colors.get("primary"),
63
+ accent=colors.get("accent"),
64
+ bg1=colors.get("bg1"),
65
+ bg2=colors.get("bg2"),
66
+ bg3=colors.get("bg3"),
67
+ )
68
+ if font_family:
69
+ set_font(font_family)
70
+ set_meta(project_name, year, module, description, meta_items)
71
 
72
 
73
  def image_to_base64(image_path: str):
 
90
  gr.HTML(f"""
91
  <div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
92
  <div>
93
+ <h1 style="margin-bottom:0; color: {PRIMARY_COLOR}; font-size: 2.5em; font-weight: bold;"> {PROJECT_NAME} </h1>
94
  <h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
95
  </div>
96
  </div>
 
115
  """
116
  return gr.HTML(footer_html)
117
 
118
+ def _build_custom_css() -> str:
119
+ return f"""
120
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
121
 
122
+ .gradio-container {{
123
  min-height: 100vh !important;
124
  width: 100vw !important;
125
  margin: 0 !important;
126
  padding: 0px !important;
127
+ background: linear-gradient(135deg, {BG1} 0%, {BG2} 50%, {BG3} 100%);
128
  background-size: 600% 600%;
129
  animation: gradientBG 7s ease infinite;
130
+ }}
131
+
132
+ /* Global font setup */
133
+ body, .gradio-container, .gr-block, .gr-markdown, .gr-button, .gr-input,
134
+ .gr-dropdown, .gr-number, .gr-plot, .gr-dataframe, .gr-accordion, .gr-form,
135
+ .gr-textbox, .gr-html, table, th, td, label, h1, h2, h3, h4, h5, h6, p, span, div {{
136
+ font-family: {FONT_FAMILY} !important;
137
+ }}
138
 
139
+ @keyframes gradientBG {{
140
+ 0% {{background-position: 0% 50%;}}
141
+ 50% {{background-position: 100% 50%;}}
142
+ 100% {{background-position: 0% 50%;}}
143
+ }}
144
 
145
  /* Minimize spacing and padding */
146
+ .content-wrap {{
147
  padding: 2px !important;
148
  margin: 0 !important;
149
+ }}
150
 
151
  /* Reduce component spacing */
152
+ .gr-row {{
153
  gap: 5px !important;
154
  margin: 2px 0 !important;
155
+ }}
156
 
157
+ .gr-column {{
158
  gap: 4px !important;
159
  padding: 4px !important;
160
+ }}
161
 
162
  /* Accordion optimization */
163
+ .gr-accordion {{
164
  margin: 4px 0 !important;
165
+ }}
166
 
167
+ .gr-accordion .gr-accordion-content {{
168
  padding: 2px !important;
169
+ }}
170
 
171
  /* Form elements spacing */
172
+ .gr-form {{
173
  gap: 2px !important;
174
+ }}
175
 
176
  /* Button styling */
177
+ .gr-button {{
178
  margin: 2px 0 !important;
179
+ }}
180
 
181
  /* DataFrame optimization */
182
+ .gr-dataframe {{
183
  margin: 4px 0 !important;
184
+ }}
185
 
186
  /* Remove horizontal scroll from data preview */
187
+ .gr-dataframe .wrap {{
188
  overflow-x: auto !important;
189
  max-width: 100% !important;
190
+ }}
191
 
192
  /* Plot optimization */
193
+ .gr-plot {{
194
  margin: 4px 0 !important;
195
+ }}
196
 
197
  /* Reduce markdown margins */
198
+ .gr-markdown {{
199
  margin: 2px 0 !important;
200
+ }}
201
 
202
  /* Footer positioning */
203
+ .sticky-footer {{
204
  position: fixed;
205
  bottom: 0px;
206
  left: 0;
207
  width: 100%;
208
+ background: {BG1};
209
  padding: 6px !important;
210
  box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
211
  z-index: 1000;
212
+ }}
213
  """
214
+
215
+ # Initialize CSS using defaults
216
+ custom_css = _build_custom_css()
217
+
218
+ def render_info_card(description: str = None, meta_items: list = None, icon: str = "🧠", title: str = "About this demo") -> str:
219
+ desc = description if description is not None else PROJECT_DESCRIPTION
220
+ items = meta_items if meta_items is not None else META_INFO
221
+ meta_html = " · ".join([f"<span><strong>{k}</strong>: {v}</span>" for k, v in items]) if items else ""
222
+ return f"""
223
+ <div style="margin: 8px 0 8px 0;">
224
+ <div style="background:#F5F9FF;border-left:6px solid {PRIMARY_COLOR};padding:14px 16px;border-radius:10px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
225
+ <div style="display:flex;gap:14px;align-items:flex-start;">
226
+ <div style="font-size:22px;">{icon}</div>
227
+ <div>
228
+ <div style="font-weight:700;color:{PRIMARY_COLOR};margin-bottom:4px;">{title}</div>
229
+ <div style="color:#000;font-size:14px;line-height:1.5;">{desc}</div>
230
+ <div style="margin-top:8px;color:#000;font-size:13px;">{meta_html}</div>
231
+ </div>
232
+ </div>
233
+ </div>
234
+ </div>
235
+ """
236
+
237
+ def render_disclaimer(text: str, icon: str = "⚠️", title: str = "Educational Use Only") -> str:
238
+ return f"""
239
+ <div style=\"margin: 8px 0 6px 0;\">
240
+ <div style=\"background:#FFF4F4;border-left:6px solid {ACCENT_COLOR};padding:12px 16px;border-radius:8px;box-shadow:0 1px 3px rgba(0,0,0,0.06);\">
241
+ <div style=\"display:flex;gap:10px;align-items:flex-start;color:#000;\">
242
+ <span style=\"font-size:20px\">{icon}</span>
243
+ <div>
244
+ <div style=\"font-weight:700; margin-bottom:4px;\">{title}</div>
245
+ <div style=\"font-size:14px; line-height:1.4;\">{text}</div>
246
+ </div>
247
+ </div>
248
+ </div>
249
+ </div>
250
+ """