modify logic app
Browse files- app.py +190 -104
- src/heart_disease_core.py +92 -41
- vlai_template.py +147 -39
app.py
CHANGED
|
@@ -2,14 +2,15 @@ import os
|
|
| 2 |
import gradio as gr
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import pandas as pd
|
|
|
|
| 5 |
|
| 6 |
from src.heart_disease_core import (
|
| 7 |
-
CLEVELAND_FEATURES_ORDER,
|
| 8 |
-
load_cleveland_dataframe, fit_all_models, predict_all, example_patient
|
| 9 |
)
|
| 10 |
|
| 11 |
-
APP_PRIMARY =
|
| 12 |
-
APP_ACCENT =
|
| 13 |
APP_BG = "#F7FAFC"
|
| 14 |
|
| 15 |
STATE = {
|
|
@@ -20,21 +21,35 @@ STATE = {
|
|
| 20 |
|
| 21 |
DATA_PATH = "data/cleveland.csv"
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
# -----------------------------
|
| 25 |
-
# Startup / init
|
| 26 |
-
# -----------------------------
|
| 27 |
def init_page():
|
| 28 |
-
"""
|
| 29 |
-
Load dataset from disk, fit models once, and return preview + metrics.
|
| 30 |
-
Returns plain values (no .update), to maximize Gradio compatibility.
|
| 31 |
-
"""
|
| 32 |
if not os.path.exists(DATA_PATH):
|
| 33 |
msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
|
| 34 |
return msg, pd.DataFrame(), pd.DataFrame()
|
| 35 |
|
| 36 |
-
|
| 37 |
-
df = load_cleveland_dataframe(uploaded_df=raw) # cleans, binarizes target
|
| 38 |
|
| 39 |
models, metrics = fit_all_models(df)
|
| 40 |
STATE["df"] = df
|
|
@@ -46,42 +61,60 @@ def init_page():
|
|
| 46 |
return msg, head, metrics
|
| 47 |
|
| 48 |
|
| 49 |
-
# -----------------------------
|
| 50 |
-
# Helpers
|
| 51 |
-
# -----------------------------
|
| 52 |
def fill_example(idx_text: str):
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
ex = example_patient(idx)
|
| 59 |
-
# Return in the strict feature order so Gradio can assign to outputs 1:1
|
| 60 |
return [ex[c] for c in CLEVELAND_FEATURES_ORDER]
|
| 61 |
|
| 62 |
|
| 63 |
def _bar_for_models(results: dict):
|
| 64 |
names = list(results.keys())
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
fig = go.Figure()
|
| 68 |
-
fig.add_bar(x=names, y=
|
| 69 |
fig.update_layout(
|
| 70 |
-
title="Model
|
| 71 |
-
yaxis_title="
|
| 72 |
xaxis_title="Model",
|
| 73 |
yaxis=dict(range=[0, 1]),
|
| 74 |
plot_bgcolor="white",
|
|
|
|
|
|
|
| 75 |
height=420,
|
| 76 |
margin=dict(l=30, r=20, t=60, b=40)
|
| 77 |
)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
colors[names.index("Ensemble (Soft Voting)")] = APP_ACCENT
|
| 82 |
-
elif len(colors) > 0:
|
| 83 |
-
colors[-1] = APP_ACCENT
|
| 84 |
-
fig.data[0].marker.color = colors
|
| 85 |
return fig
|
| 86 |
|
| 87 |
|
|
@@ -93,17 +126,43 @@ def run_predict(*vals):
|
|
| 93 |
results = predict_all(STATE["models"], input_dict)
|
| 94 |
|
| 95 |
final = results["Ensemble (Soft Voting)"]
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
rows = []
|
| 103 |
for name, r in results.items():
|
|
|
|
| 104 |
rows.append({
|
| 105 |
"Model": name,
|
| 106 |
-
"
|
|
|
|
| 107 |
"P(No disease)": round(r["prob_0"], 3),
|
| 108 |
"P(Heart disease)": round(r["prob_1"], 3),
|
| 109 |
})
|
|
@@ -111,80 +170,107 @@ def run_predict(*vals):
|
|
| 111 |
|
| 112 |
fig = _bar_for_models(results)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
return title_md, fig, "**Per-Model Predictions**", table_df
|
| 116 |
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
.
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
gr.Markdown("
|
| 129 |
|
| 130 |
-
with gr.Row(equal_height=False):
|
| 131 |
# LEFT: data preview + inputs
|
| 132 |
with gr.Column(scale=45):
|
| 133 |
-
gr.
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
gr.
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# RIGHT: outputs
|
| 169 |
with gr.Column(scale=55):
|
| 170 |
-
gr.Markdown("### 📈 Predictions")
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
# Bind events
|
| 185 |
demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
|
| 186 |
|
| 187 |
-
|
|
|
|
| 188 |
fn=fill_example,
|
| 189 |
inputs=[ex_selector],
|
| 190 |
outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
|
|
@@ -193,8 +279,8 @@ h1, h2, h3, h4 {{ color: {APP_PRIMARY}; }}
|
|
| 193 |
predict_btn.click(
|
| 194 |
fn=run_predict,
|
| 195 |
inputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal],
|
| 196 |
-
outputs=[
|
| 197 |
)
|
| 198 |
|
| 199 |
if __name__ == "__main__":
|
| 200 |
-
demo.launch()
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import pandas as pd
|
| 5 |
+
import vlai_template
|
| 6 |
|
| 7 |
from src.heart_disease_core import (
|
| 8 |
+
CLEVELAND_FEATURES_ORDER,
|
| 9 |
+
load_cleveland_dataframe, fit_all_models, predict_all, example_patient, get_example_labels
|
| 10 |
)
|
| 11 |
|
| 12 |
+
APP_PRIMARY = vlai_template.PRIMARY_COLOR
|
| 13 |
+
APP_ACCENT = vlai_template.ACCENT_COLOR
|
| 14 |
APP_BG = "#F7FAFC"
|
| 15 |
|
| 16 |
STATE = {
|
|
|
|
| 21 |
|
| 22 |
DATA_PATH = "data/cleveland.csv"
|
| 23 |
|
| 24 |
+
vlai_template.set_meta(
|
| 25 |
+
project_name="Heart Disease Diagnosis Project",
|
| 26 |
+
year="2025",
|
| 27 |
+
module="03",
|
| 28 |
+
description="Predict heart disease risk from patient data with ML models trained on the Cleveland dataset.",
|
| 29 |
+
meta_items=[
|
| 30 |
+
("Dataset", "Cleveland Heart Disease"),
|
| 31 |
+
("Models", "Decision Tree, k-NN, Naive Bayes"),
|
| 32 |
+
("Ensemble", "Soft Voting"),
|
| 33 |
+
],
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
force_light_theme_js = """
|
| 37 |
+
() => {
|
| 38 |
+
const params = new URLSearchParams(window.location.search);
|
| 39 |
+
if (!params.has('__theme')) {
|
| 40 |
+
params.set('__theme', 'light');
|
| 41 |
+
window.location.search = params.toString();
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
"""
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
def init_page():
|
| 47 |
+
"""Load dataset, train models, and return status, preview, metrics."""
|
|
|
|
|
|
|
|
|
|
| 48 |
if not os.path.exists(DATA_PATH):
|
| 49 |
msg = f"❌ Dataset not found at '{DATA_PATH}'. Please place Cleveland CSV there."
|
| 50 |
return msg, pd.DataFrame(), pd.DataFrame()
|
| 51 |
|
| 52 |
+
df = load_cleveland_dataframe(file_path=DATA_PATH)
|
|
|
|
| 53 |
|
| 54 |
models, metrics = fit_all_models(df)
|
| 55 |
STATE["df"] = df
|
|
|
|
| 61 |
return msg, head, metrics
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
def fill_example(idx_text: str):
|
| 65 |
+
import re
|
| 66 |
+
match = re.search(r'Example (\d+)', idx_text)
|
| 67 |
+
if match:
|
| 68 |
+
idx = int(match.group(1)) - 1
|
| 69 |
+
else:
|
| 70 |
+
idx = 1
|
| 71 |
+
|
| 72 |
ex = example_patient(idx)
|
|
|
|
| 73 |
return [ex[c] for c in CLEVELAND_FEATURES_ORDER]
|
| 74 |
|
| 75 |
|
| 76 |
def _bar_for_models(results: dict):
|
| 77 |
names = list(results.keys())
|
| 78 |
+
confidences = []
|
| 79 |
+
predictions_text = []
|
| 80 |
+
bar_colors = []
|
| 81 |
+
line_colors = []
|
| 82 |
+
line_widths = []
|
| 83 |
+
|
| 84 |
+
for n in names:
|
| 85 |
+
r = results[n]
|
| 86 |
+
if r["label"] == 1:
|
| 87 |
+
confidences.append(r["prob_1"])
|
| 88 |
+
predictions_text.append("🫀 Heart Disease")
|
| 89 |
+
bar_colors.append("#C4314B")
|
| 90 |
+
else:
|
| 91 |
+
confidences.append(r["prob_0"])
|
| 92 |
+
predictions_text.append("✅ No Heart Disease")
|
| 93 |
+
bar_colors.append("#2E7D32")
|
| 94 |
+
line_colors.append("rgba(0,0,0,0.15)")
|
| 95 |
+
line_widths.append(1.0)
|
| 96 |
+
|
| 97 |
+
if "Ensemble (Soft Voting)" in names:
|
| 98 |
+
idx = names.index("Ensemble (Soft Voting)")
|
| 99 |
+
line_colors[idx] = "#000000"
|
| 100 |
+
line_widths[idx] = 2.5
|
| 101 |
|
| 102 |
fig = go.Figure()
|
| 103 |
+
fig.add_bar(x=names, y=confidences, text=predictions_text, textposition="auto")
|
| 104 |
fig.update_layout(
|
| 105 |
+
title="Model Predictions",
|
| 106 |
+
yaxis_title="Prediction Confidence",
|
| 107 |
xaxis_title="Model",
|
| 108 |
yaxis=dict(range=[0, 1]),
|
| 109 |
plot_bgcolor="white",
|
| 110 |
+
paper_bgcolor="white",
|
| 111 |
+
font=dict(color="black", size=12),
|
| 112 |
height=420,
|
| 113 |
margin=dict(l=30, r=20, t=60, b=40)
|
| 114 |
)
|
| 115 |
+
fig.data[0].marker.color = bar_colors
|
| 116 |
+
fig.data[0].marker.line.color = line_colors
|
| 117 |
+
fig.data[0].marker.line.width = line_widths
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
return fig
|
| 119 |
|
| 120 |
|
|
|
|
| 126 |
results = predict_all(STATE["models"], input_dict)
|
| 127 |
|
| 128 |
final = results["Ensemble (Soft Voting)"]
|
| 129 |
+
ensemble_color = "#C4314B" if final["label"] == 1 else "#2E7D32"
|
| 130 |
+
ensemble_prediction = "🫀 **Heart Disease Detected**" if final["label"] == 1 else "✅ **No Heart Disease**"
|
| 131 |
+
|
| 132 |
+
ensemble_md = f"""
|
| 133 |
+
<div style=\"border: 3px solid {ensemble_color}; border-radius: 10px; padding: 20px; margin: 15px 0; background: white;\">
|
| 134 |
+
<h3 style=\"margin: 0 0 15px 0; color: {ensemble_color};\">🎯 Ensemble Prediction (Final Result)</h3>
|
| 135 |
+
<p style=\"margin: 10px 0; font-size: 18px; color: black;\"><strong>{ensemble_prediction}</strong></p>
|
| 136 |
+
<p style=\"margin: 5px 0; font-size: 16px; color: black;\"><strong>Confidence:</strong> {final['prob_1']:.1%}</p>
|
| 137 |
+
</div>
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
model_predictions = []
|
| 141 |
+
for name, r in results.items():
|
| 142 |
+
prediction_text = "🫀 **Heart Disease Detected**" if r["label"] == 1 else "✅ **No Heart Disease**"
|
| 143 |
+
confidence = r["prob_1"] if r["label"] == 1 else r["prob_0"]
|
| 144 |
+
color = "#C4314B" if r["label"] == 1 else "#2E7D32"
|
| 145 |
+
|
| 146 |
+
model_predictions.append(f"""
|
| 147 |
+
<div style=\"border: 2px solid {color}; border-radius: 8px; padding: 15px; margin: 10px 0; background: white;\">
|
| 148 |
+
<h4 style=\"margin: 0 0 10px 0; color: {color};\">{name}</h4>
|
| 149 |
+
<p style=\"margin: 5px 0; font-size: 16px; color: black;\"><strong>Prediction:</strong> {prediction_text}</p>
|
| 150 |
+
<p style=\"margin: 5px 0; font-size: 14px; color: black;\"><strong>Confidence:</strong> {confidence:.1%}</p>
|
| 151 |
+
<p style=\"margin: 5px 0; font-size: 12px; color: #666;\">
|
| 152 |
+
P(No disease): {r['prob_0']:.3f} | P(Heart disease): {r['prob_1']:.3f}
|
| 153 |
+
</p>
|
| 154 |
+
</div>
|
| 155 |
+
""")
|
| 156 |
+
|
| 157 |
+
all_predictions = "\n".join(model_predictions)
|
| 158 |
+
|
| 159 |
rows = []
|
| 160 |
for name, r in results.items():
|
| 161 |
+
confidence = r["prob_1"] if r["label"] == 1 else r["prob_0"]
|
| 162 |
rows.append({
|
| 163 |
"Model": name,
|
| 164 |
+
"Prediction": "Heart Disease" if r["label"] == 1 else "No Heart Disease",
|
| 165 |
+
"Confidence": f"{confidence:.1%}",
|
| 166 |
"P(No disease)": round(r["prob_0"], 3),
|
| 167 |
"P(Heart disease)": round(r["prob_1"], 3),
|
| 168 |
})
|
|
|
|
| 170 |
|
| 171 |
fig = _bar_for_models(results)
|
| 172 |
|
| 173 |
+
return fig, "\n".join(model_predictions), table_df
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
+
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
|
| 177 |
+
vlai_template.create_header()
|
| 178 |
+
gr.HTML(vlai_template.render_info_card(icon="🫀", title="About this demo"))
|
| 179 |
+
gr.HTML(vlai_template.render_disclaimer(
|
| 180 |
+
text=(
|
| 181 |
+
"This interactive heart disease prediction demo is provided strictly for educational purposes. "
|
| 182 |
+
"It is not intended for clinical use and must not be relied upon for medical advice, diagnosis, "
|
| 183 |
+
"treatment, or decision-making. Always consult a qualified healthcare professional."
|
| 184 |
+
)
|
| 185 |
+
))
|
| 186 |
+
gr.Markdown("### 🫀 **How to Use**: Enter patient features → Run prediction → View ensemble results!")
|
| 187 |
|
| 188 |
+
with gr.Row(equal_height=False, variant="panel"):
|
| 189 |
# LEFT: data preview + inputs
|
| 190 |
with gr.Column(scale=45):
|
| 191 |
+
with gr.Accordion("📁 Dataset & Model Status", open=True):
|
| 192 |
+
status_md = gr.Markdown("Loading dataset and training models...")
|
| 193 |
+
preview = gr.DataFrame(label="Cleveland Preview (first rows)", interactive=False)
|
| 194 |
+
metrics_df = gr.DataFrame(label="Validation Metrics (80/20 split)", interactive=False)
|
| 195 |
+
|
| 196 |
+
with gr.Accordion("✍️ Enter Patient Features", open=True):
|
| 197 |
+
with gr.Row():
|
| 198 |
+
age = gr.Number(label="age (years)", value=58)
|
| 199 |
+
sex = gr.Dropdown(label="sex (0=female, 1=male)", choices=[0, 1], value=1)
|
| 200 |
+
cp = gr.Dropdown(label="cp (chest pain type 1..4)", choices=[1, 2, 3, 4], value=2)
|
| 201 |
+
trestbps = gr.Number(label="trestbps (resting BP mmHg)", value=130)
|
| 202 |
+
|
| 203 |
+
with gr.Row():
|
| 204 |
+
chol = gr.Number(label="chol (serum cholesterol mg/dl)", value=250)
|
| 205 |
+
fbs = gr.Dropdown(label="fbs (>120 mg/dl? 1/0)", choices=[0, 1], value=0)
|
| 206 |
+
restecg = gr.Dropdown(label="restecg (0..2)", choices=[0, 1, 2], value=1)
|
| 207 |
+
thalach = gr.Number(label="thalach (max heart rate)", value=150)
|
| 208 |
+
|
| 209 |
+
with gr.Row():
|
| 210 |
+
exang = gr.Dropdown(label="exang (exercise angina 1/0)", choices=[0, 1], value=0)
|
| 211 |
+
oldpeak = gr.Number(label="oldpeak (ST depression)", value=1.0)
|
| 212 |
+
slope = gr.Dropdown(label="slope (1..3)", choices=[1, 2, 3], value=1)
|
| 213 |
+
ca = gr.Dropdown(label="ca (major vessels 0..3)", choices=[0, 1, 2, 3], value=0)
|
| 214 |
+
|
| 215 |
+
thal = gr.Dropdown(label="thal (3=normal, 6=fixed, 7=reversible)", choices=[3, 6, 7], value=3)
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
# Get actual labels from the dataset - only 2 examples
|
| 219 |
+
try:
|
| 220 |
+
labels = get_example_labels()
|
| 221 |
+
choices = []
|
| 222 |
+
# Only use first two examples: one no disease, one disease
|
| 223 |
+
for i in range(min(2, len(labels))):
|
| 224 |
+
label_text = "No Heart Disease" if labels[i] == 0 else "Heart Disease"
|
| 225 |
+
choices.append(f"Example {i+1} ({label_text})")
|
| 226 |
+
default_choice = choices[0] if choices else "Example 1"
|
| 227 |
+
except:
|
| 228 |
+
choices = ["Example 1 (No Heart Disease)", "Example 2 (Heart Disease)"]
|
| 229 |
+
default_choice = "Example 1 (No Heart Disease)"
|
| 230 |
+
|
| 231 |
+
ex_selector = gr.Dropdown(
|
| 232 |
+
label="Select Example Patient",
|
| 233 |
+
choices=choices,
|
| 234 |
+
value=default_choice
|
| 235 |
+
)
|
| 236 |
+
predict_btn = gr.Button("🔍 Predict", variant="primary")
|
| 237 |
|
| 238 |
# RIGHT: outputs
|
| 239 |
with gr.Column(scale=55):
|
| 240 |
+
gr.Markdown("### 📈 Model Predictions")
|
| 241 |
+
bar_out = gr.Plot(label="Model Predictions Overview")
|
| 242 |
+
sub_md = gr.Markdown("**Individual Model Results**")
|
| 243 |
+
table_out = gr.DataFrame(label="All Model Predictions", interactive=False)
|
| 244 |
+
|
| 245 |
+
gr.Markdown("""
|
| 246 |
+
## 📋 **Notes**
|
| 247 |
+
|
| 248 |
+
- **Models are trained once at launch** on `data/cleveland.csv` (80/20 split).
|
| 249 |
+
- **Target is binarized automatically** (0 = no disease, >0 = disease).
|
| 250 |
+
- **Ensemble uses soft voting** over Decision Tree, k-NN, and Naive Bayes.
|
| 251 |
+
- **Feature descriptions**:
|
| 252 |
+
- `age`: Patient age in years
|
| 253 |
+
- `sex`: Gender (0=female, 1=male)
|
| 254 |
+
- `cp`: Chest pain type (1-4)
|
| 255 |
+
- `trestbps`: Resting blood pressure (mmHg)
|
| 256 |
+
- `chol`: Serum cholesterol (mg/dl)
|
| 257 |
+
- `fbs`: Fasting blood sugar >120 mg/dl (1=true, 0=false)
|
| 258 |
+
- `restecg`: Resting ECG results (0-2)
|
| 259 |
+
- `thalach`: Maximum heart rate achieved
|
| 260 |
+
- `exang`: Exercise induced angina (1=yes, 0=no)
|
| 261 |
+
- `oldpeak`: ST depression induced by exercise
|
| 262 |
+
- `slope`: Slope of peak exercise ST segment (1-3)
|
| 263 |
+
- `ca`: Number of major vessels colored by fluoroscopy (0-3)
|
| 264 |
+
- `thal`: Thalassemia (3=normal, 6=fixed defect, 7=reversible defect)
|
| 265 |
+
""")
|
| 266 |
+
|
| 267 |
+
vlai_template.create_footer()
|
| 268 |
|
| 269 |
# Bind events
|
| 270 |
demo.load(fn=init_page, inputs=None, outputs=[status_md, preview, metrics_df])
|
| 271 |
|
| 272 |
+
# Auto-fill when example is selected
|
| 273 |
+
ex_selector.change(
|
| 274 |
fn=fill_example,
|
| 275 |
inputs=[ex_selector],
|
| 276 |
outputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal]
|
|
|
|
| 279 |
predict_btn.click(
|
| 280 |
fn=run_predict,
|
| 281 |
inputs=[age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal],
|
| 282 |
+
outputs=[bar_out, sub_md, table_out]
|
| 283 |
)
|
| 284 |
|
| 285 |
if __name__ == "__main__":
|
| 286 |
+
demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])
|
src/heart_disease_core.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# src/heart_disease_core.py
|
| 2 |
import os
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
@@ -9,7 +8,7 @@ from sklearn.preprocessing import OneHotEncoder
|
|
| 9 |
from sklearn.compose import ColumnTransformer
|
| 10 |
from sklearn.pipeline import Pipeline
|
| 11 |
from sklearn.impute import SimpleImputer
|
| 12 |
-
from sklearn.metrics import roc_auc_score
|
| 13 |
from sklearn.tree import DecisionTreeClassifier
|
| 14 |
from sklearn.neighbors import KNeighborsClassifier
|
| 15 |
from sklearn.naive_bayes import GaussianNB
|
|
@@ -20,48 +19,41 @@ CLEVELAND_FEATURES_ORDER: List[str] = [
|
|
| 20 |
"age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
|
| 21 |
"thalach", "exang", "oldpeak", "slope", "ca", "thal"
|
| 22 |
]
|
| 23 |
-
TARGET_COL = "target"
|
| 24 |
|
| 25 |
CATEGORICAL_CHOICES = {
|
| 26 |
-
"sex": [0, 1],
|
| 27 |
-
"cp": [0, 1, 2, 3],
|
| 28 |
-
"fbs": [0, 1],
|
| 29 |
-
"restecg": [0, 1, 2],
|
| 30 |
-
"exang": [0, 1],
|
| 31 |
-
"slope": [0, 1, 2],
|
| 32 |
-
"ca": [0, 1, 2, 3],
|
| 33 |
-
"thal": [1, 2, 3],
|
| 34 |
}
|
| 35 |
|
| 36 |
NUMERIC_COLS = ["age", "trestbps", "chol", "thalach", "oldpeak"]
|
| 37 |
CATEGORICAL_COLS = ["sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal"]
|
| 38 |
|
| 39 |
def _coerce_and_clean(df: pd.DataFrame) -> pd.DataFrame:
|
| 40 |
-
"""Clean '?'
|
| 41 |
df = df.copy()
|
| 42 |
-
# Standardize columns if they are present with any case
|
| 43 |
colmap = {c.lower(): c for c in df.columns}
|
| 44 |
for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
|
| 45 |
if col not in df.columns and col in colmap:
|
| 46 |
-
df[col] = df.pop(colmap[col])
|
| 47 |
|
| 48 |
-
# Replace '?' with NaN and cast
|
| 49 |
for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
|
| 50 |
if col in df.columns:
|
| 51 |
df[col] = pd.to_numeric(df[col].replace("?", np.nan), errors="coerce")
|
| 52 |
|
| 53 |
-
# Binarize target if it appears as 0..4 (UCI often uses 0 vs 1..4 disease)
|
| 54 |
if TARGET_COL in df.columns:
|
| 55 |
df[TARGET_COL] = (df[TARGET_COL] > 0).astype(int)
|
| 56 |
|
| 57 |
return df
|
| 58 |
|
| 59 |
def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
|
| 60 |
-
"""
|
| 61 |
-
Load the Cleveland Heart Disease dataset.
|
| 62 |
-
Priority: uploaded_df > file_path > raise.
|
| 63 |
-
Expect columns CLEVELAND_FEATURES_ORDER + TARGET_COL.
|
| 64 |
-
"""
|
| 65 |
if uploaded_df is not None:
|
| 66 |
df = _coerce_and_clean(uploaded_df)
|
| 67 |
missing = [c for c in CLEVELAND_FEATURES_ORDER + [TARGET_COL] if c not in df.columns]
|
|
@@ -71,7 +63,19 @@ def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optio
|
|
| 71 |
|
| 72 |
if file_path is not None and os.path.exists(file_path):
|
| 73 |
if file_path.endswith(".csv"):
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
else:
|
| 76 |
df = pd.read_excel(file_path)
|
| 77 |
df = _coerce_and_clean(df)
|
|
@@ -171,14 +175,29 @@ def fit_all_models(df: pd.DataFrame, test_size: float = 0.2, random_state: int =
|
|
| 171 |
|
| 172 |
for name, pipe in models.items():
|
| 173 |
pipe.fit(X_tr, y_tr)
|
|
|
|
|
|
|
| 174 |
if hasattr(pipe, "predict_proba"):
|
| 175 |
proba = pipe.predict_proba(X_te)[:, 1]
|
| 176 |
auc = roc_auc_score(y_te, proba)
|
| 177 |
else:
|
| 178 |
-
# Fallback if
|
| 179 |
-
|
| 180 |
-
auc = roc_auc_score(y_te,
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
metrics_df = pd.DataFrame(metrics).sort_values("ROC-AUC", ascending=False, ignore_index=True)
|
| 184 |
return models, metrics_df
|
|
@@ -210,19 +229,51 @@ def predict_all(models: Dict[str, Pipeline], input_dict: Dict[str, float]) -> Di
|
|
| 210 |
|
| 211 |
def example_patient(index: int = 0) -> Dict[str, float]:
|
| 212 |
"""
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
"""
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
thalach=168, exang=0, oldpeak=0.0, slope=2, ca=0, thal=2),
|
| 220 |
-
# Borderline
|
| 221 |
-
dict(age=58, sex=1, cp=2, trestbps=138, chol=250, fbs=0, restecg=0,
|
| 222 |
-
thalach=150, exang=0, oldpeak=1.0, slope=1, ca=1, thal=2),
|
| 223 |
-
# Likely positive (disease)
|
| 224 |
-
dict(age=63, sex=1, cp=3, trestbps=145, chol=320, fbs=1, restecg=2,
|
| 225 |
-
thalach=130, exang=1, oldpeak=2.8, slope=0, ca=2, thal=3),
|
| 226 |
-
]
|
| 227 |
-
index = max(0, min(index, len(examples) - 1))
|
| 228 |
-
return examples[index]
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
|
|
|
| 8 |
from sklearn.compose import ColumnTransformer
|
| 9 |
from sklearn.pipeline import Pipeline
|
| 10 |
from sklearn.impute import SimpleImputer
|
| 11 |
+
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
|
| 12 |
from sklearn.tree import DecisionTreeClassifier
|
| 13 |
from sklearn.neighbors import KNeighborsClassifier
|
| 14 |
from sklearn.naive_bayes import GaussianNB
|
|
|
|
| 19 |
"age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
|
| 20 |
"thalach", "exang", "oldpeak", "slope", "ca", "thal"
|
| 21 |
]
|
| 22 |
+
TARGET_COL = "target"
|
| 23 |
|
| 24 |
CATEGORICAL_CHOICES = {
|
| 25 |
+
"sex": [0, 1],
|
| 26 |
+
"cp": [0, 1, 2, 3],
|
| 27 |
+
"fbs": [0, 1],
|
| 28 |
+
"restecg": [0, 1, 2],
|
| 29 |
+
"exang": [0, 1],
|
| 30 |
+
"slope": [0, 1, 2],
|
| 31 |
+
"ca": [0, 1, 2, 3],
|
| 32 |
+
"thal": [1, 2, 3],
|
| 33 |
}
|
| 34 |
|
| 35 |
NUMERIC_COLS = ["age", "trestbps", "chol", "thalach", "oldpeak"]
|
| 36 |
CATEGORICAL_COLS = ["sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal"]
|
| 37 |
|
| 38 |
def _coerce_and_clean(df: pd.DataFrame) -> pd.DataFrame:
|
| 39 |
+
"""Clean '?', cast numerics, normalize column names, and binarize target."""
|
| 40 |
df = df.copy()
|
|
|
|
| 41 |
colmap = {c.lower(): c for c in df.columns}
|
| 42 |
for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
|
| 43 |
if col not in df.columns and col in colmap:
|
| 44 |
+
df[col] = df.pop(colmap[col])
|
| 45 |
|
|
|
|
| 46 |
for col in CLEVELAND_FEATURES_ORDER + [TARGET_COL]:
|
| 47 |
if col in df.columns:
|
| 48 |
df[col] = pd.to_numeric(df[col].replace("?", np.nan), errors="coerce")
|
| 49 |
|
|
|
|
| 50 |
if TARGET_COL in df.columns:
|
| 51 |
df[TARGET_COL] = (df[TARGET_COL] > 0).astype(int)
|
| 52 |
|
| 53 |
return df
|
| 54 |
|
| 55 |
def load_cleveland_dataframe(file_path: Optional[str] = None, uploaded_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
|
| 56 |
+
"""Load Cleveland dataset from upload or file path and ensure schema."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
if uploaded_df is not None:
|
| 58 |
df = _coerce_and_clean(uploaded_df)
|
| 59 |
missing = [c for c in CLEVELAND_FEATURES_ORDER + [TARGET_COL] if c not in df.columns]
|
|
|
|
| 63 |
|
| 64 |
if file_path is not None and os.path.exists(file_path):
|
| 65 |
if file_path.endswith(".csv"):
|
| 66 |
+
# Try reading with headers first; fall back to no header
|
| 67 |
+
try:
|
| 68 |
+
df = pd.read_csv(file_path)
|
| 69 |
+
if len(df.columns) == len(CLEVELAND_FEATURES_ORDER) + 1: # +1 for target
|
| 70 |
+
first_row_numeric = all(pd.to_numeric(df.iloc[0], errors='coerce').notna())
|
| 71 |
+
if first_row_numeric:
|
| 72 |
+
# Re-read without headers and assign names
|
| 73 |
+
df = pd.read_csv(file_path, header=None)
|
| 74 |
+
df.columns = CLEVELAND_FEATURES_ORDER + [TARGET_COL]
|
| 75 |
+
except:
|
| 76 |
+
# Fallback: read without headers
|
| 77 |
+
df = pd.read_csv(file_path, header=None)
|
| 78 |
+
df.columns = CLEVELAND_FEATURES_ORDER + [TARGET_COL]
|
| 79 |
else:
|
| 80 |
df = pd.read_excel(file_path)
|
| 81 |
df = _coerce_and_clean(df)
|
|
|
|
| 175 |
|
| 176 |
for name, pipe in models.items():
|
| 177 |
pipe.fit(X_tr, y_tr)
|
| 178 |
+
# Predictions and probabilities
|
| 179 |
+
y_pred = pipe.predict(X_te)
|
| 180 |
if hasattr(pipe, "predict_proba"):
|
| 181 |
proba = pipe.predict_proba(X_te)[:, 1]
|
| 182 |
auc = roc_auc_score(y_te, proba)
|
| 183 |
else:
|
| 184 |
+
# Fallback if probabilities are not available
|
| 185 |
+
proba = None
|
| 186 |
+
auc = roc_auc_score(y_te, y_pred)
|
| 187 |
+
|
| 188 |
+
acc = accuracy_score(y_te, y_pred)
|
| 189 |
+
prec = precision_score(y_te, y_pred, zero_division=0)
|
| 190 |
+
rec = recall_score(y_te, y_pred, zero_division=0)
|
| 191 |
+
f1 = f1_score(y_te, y_pred, zero_division=0)
|
| 192 |
+
|
| 193 |
+
metrics.append({
|
| 194 |
+
"model": name,
|
| 195 |
+
"ROC-AUC": round(float(auc), 4),
|
| 196 |
+
"Accuracy": round(float(acc), 4),
|
| 197 |
+
"Precision": round(float(prec), 4),
|
| 198 |
+
"Recall": round(float(rec), 4),
|
| 199 |
+
"F1": round(float(f1), 4),
|
| 200 |
+
})
|
| 201 |
|
| 202 |
metrics_df = pd.DataFrame(metrics).sort_values("ROC-AUC", ascending=False, ignore_index=True)
|
| 203 |
return models, metrics_df
|
|
|
|
| 229 |
|
| 230 |
def example_patient(index: int = 0) -> Dict[str, float]:
|
| 231 |
"""
|
| 232 |
+
Get example patients with specific features provided by user.
|
| 233 |
+
"""
|
| 234 |
+
# Example 1: No heart disease (37,1,3,130,250,0,0,187,0,3.5,3,0,3,0)
|
| 235 |
+
# Example 2: Heart disease (56,1,3,130,256,1,2,142,1,0.6,2,1,6,2)
|
| 236 |
+
|
| 237 |
+
if index == 0:
|
| 238 |
+
# No heart disease example
|
| 239 |
+
return {
|
| 240 |
+
"age": 37.0,
|
| 241 |
+
"sex": 1.0,
|
| 242 |
+
"cp": 3.0,
|
| 243 |
+
"trestbps": 130.0,
|
| 244 |
+
"chol": 250.0,
|
| 245 |
+
"fbs": 0.0,
|
| 246 |
+
"restecg": 0.0,
|
| 247 |
+
"thalach": 187.0,
|
| 248 |
+
"exang": 0.0,
|
| 249 |
+
"oldpeak": 3.5,
|
| 250 |
+
"slope": 3.0,
|
| 251 |
+
"ca": 0.0,
|
| 252 |
+
"thal": 3.0
|
| 253 |
+
}
|
| 254 |
+
else:
|
| 255 |
+
# Heart disease example
|
| 256 |
+
return {
|
| 257 |
+
"age": 56.0,
|
| 258 |
+
"sex": 1.0,
|
| 259 |
+
"cp": 3.0,
|
| 260 |
+
"trestbps": 130.0,
|
| 261 |
+
"chol": 256.0,
|
| 262 |
+
"fbs": 1.0,
|
| 263 |
+
"restecg": 2.0,
|
| 264 |
+
"thalach": 142.0,
|
| 265 |
+
"exang": 1.0,
|
| 266 |
+
"oldpeak": 0.6,
|
| 267 |
+
"slope": 2.0,
|
| 268 |
+
"ca": 1.0,
|
| 269 |
+
"thal": 6.0
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
def get_example_labels() -> List[int]:
|
| 273 |
+
"""
|
| 274 |
+
Get the labels for the example patients to display in the UI.
|
| 275 |
+
Returns list of labels for the specific examples provided.
|
| 276 |
"""
|
| 277 |
+
# Example 1: No heart disease (target = 0)
|
| 278 |
+
# Example 2: Heart disease (target = 2, binarized to 1)
|
| 279 |
+
return [0, 1] # First example: no disease, second example: heart disease
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vlai_template.py
CHANGED
|
@@ -1,11 +1,73 @@
|
|
| 1 |
import os, base64
|
| 2 |
import gradio as gr
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
AIO_YEAR = "2025"
|
| 7 |
-
AIO_MODULE = "
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def image_to_base64(image_path: str):
|
|
@@ -28,7 +90,7 @@ def create_header():
|
|
| 28 |
gr.HTML(f"""
|
| 29 |
<div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
|
| 30 |
<div>
|
| 31 |
-
<h1 style="margin-bottom:0; color:
|
| 32 |
<h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
|
| 33 |
</div>
|
| 34 |
</div>
|
|
@@ -53,90 +115,136 @@ def create_footer():
|
|
| 53 |
"""
|
| 54 |
return gr.HTML(footer_html)
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
.gradio-container {
|
| 59 |
min-height: 100vh !important;
|
| 60 |
width: 100vw !important;
|
| 61 |
margin: 0 !important;
|
| 62 |
padding: 0px !important;
|
| 63 |
-
background: linear-gradient(135deg,
|
| 64 |
background-size: 600% 600%;
|
| 65 |
animation: gradientBG 7s ease infinite;
|
| 66 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
@keyframes gradientBG {
|
| 69 |
-
0% {background-position: 0% 50%;}
|
| 70 |
-
50% {background-position: 100% 50%;}
|
| 71 |
-
100% {background-position: 0% 50%;}
|
| 72 |
-
}
|
| 73 |
|
| 74 |
/* Minimize spacing and padding */
|
| 75 |
-
.content-wrap {
|
| 76 |
padding: 2px !important;
|
| 77 |
margin: 0 !important;
|
| 78 |
-
}
|
| 79 |
|
| 80 |
/* Reduce component spacing */
|
| 81 |
-
.gr-row {
|
| 82 |
gap: 5px !important;
|
| 83 |
margin: 2px 0 !important;
|
| 84 |
-
}
|
| 85 |
|
| 86 |
-
.gr-column {
|
| 87 |
gap: 4px !important;
|
| 88 |
padding: 4px !important;
|
| 89 |
-
}
|
| 90 |
|
| 91 |
/* Accordion optimization */
|
| 92 |
-
.gr-accordion {
|
| 93 |
margin: 4px 0 !important;
|
| 94 |
-
}
|
| 95 |
|
| 96 |
-
.gr-accordion .gr-accordion-content {
|
| 97 |
padding: 2px !important;
|
| 98 |
-
}
|
| 99 |
|
| 100 |
/* Form elements spacing */
|
| 101 |
-
.gr-form {
|
| 102 |
gap: 2px !important;
|
| 103 |
-
}
|
| 104 |
|
| 105 |
/* Button styling */
|
| 106 |
-
.gr-button {
|
| 107 |
margin: 2px 0 !important;
|
| 108 |
-
}
|
| 109 |
|
| 110 |
/* DataFrame optimization */
|
| 111 |
-
.gr-dataframe {
|
| 112 |
margin: 4px 0 !important;
|
| 113 |
-
}
|
| 114 |
|
| 115 |
/* Remove horizontal scroll from data preview */
|
| 116 |
-
.gr-dataframe .wrap {
|
| 117 |
overflow-x: auto !important;
|
| 118 |
max-width: 100% !important;
|
| 119 |
-
}
|
| 120 |
|
| 121 |
/* Plot optimization */
|
| 122 |
-
.gr-plot {
|
| 123 |
margin: 4px 0 !important;
|
| 124 |
-
}
|
| 125 |
|
| 126 |
/* Reduce markdown margins */
|
| 127 |
-
.gr-markdown {
|
| 128 |
margin: 2px 0 !important;
|
| 129 |
-
}
|
| 130 |
|
| 131 |
/* Footer positioning */
|
| 132 |
-
.sticky-footer {
|
| 133 |
position: fixed;
|
| 134 |
bottom: 0px;
|
| 135 |
left: 0;
|
| 136 |
width: 100%;
|
| 137 |
-
background:
|
| 138 |
padding: 6px !important;
|
| 139 |
box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
|
| 140 |
z-index: 1000;
|
| 141 |
-
}
|
| 142 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os, base64
|
| 2 |
import gradio as gr
|
| 3 |
|
| 4 |
+
# Theming (can be overridden by the host app)
|
| 5 |
+
PRIMARY_COLOR = "#0F6CBD" # medical calm blue
|
| 6 |
+
ACCENT_COLOR = "#C4314B" # medical alert red
|
| 7 |
+
SUCCESS_COLOR = "#2E7D32" # positive/ok
|
| 8 |
+
BG1 = "#F0F7FF"
|
| 9 |
+
BG2 = "#E8F0FA"
|
| 10 |
+
BG3 = "#DDE7F8"
|
| 11 |
+
FONT_FAMILY = "'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, 'Noto Sans', 'Liberation Sans', sans-serif"
|
| 12 |
+
|
| 13 |
+
# App metadata (overridable)
|
| 14 |
+
PROJECT_NAME = "Demo Project"
|
| 15 |
AIO_YEAR = "2025"
|
| 16 |
+
AIO_MODULE = "00"
|
| 17 |
+
PROJECT_DESCRIPTION = ""
|
| 18 |
+
META_INFO = [] # list of (label, value)
|
| 19 |
+
|
| 20 |
+
def set_colors(primary: str = None, accent: str = None, bg1: str = None, bg2: str = None, bg3: str = None):
|
| 21 |
+
"""Allow host app to set theme colors dynamically."""
|
| 22 |
+
global PRIMARY_COLOR, ACCENT_COLOR, BG1, BG2, BG3, custom_css
|
| 23 |
+
if primary:
|
| 24 |
+
PRIMARY_COLOR = primary
|
| 25 |
+
if accent:
|
| 26 |
+
ACCENT_COLOR = accent
|
| 27 |
+
if bg1:
|
| 28 |
+
BG1 = bg1
|
| 29 |
+
if bg2:
|
| 30 |
+
BG2 = bg2
|
| 31 |
+
if bg3:
|
| 32 |
+
BG3 = bg3
|
| 33 |
+
# Rebuild CSS with new colors
|
| 34 |
+
custom_css = _build_custom_css()
|
| 35 |
+
|
| 36 |
+
def set_font(font_family: str):
|
| 37 |
+
"""Allow host app to set a custom font stack (e.g., 'Inter', system fallbacks)."""
|
| 38 |
+
global FONT_FAMILY, custom_css
|
| 39 |
+
if font_family and isinstance(font_family, str):
|
| 40 |
+
FONT_FAMILY = font_family
|
| 41 |
+
custom_css = _build_custom_css()
|
| 42 |
+
|
| 43 |
+
def set_meta(project_name: str = None, year: str = None, module: str = None, description: str = None, meta_items: list = None):
|
| 44 |
+
"""Set project metadata used across the header and info sections."""
|
| 45 |
+
global PROJECT_NAME, AIO_YEAR, AIO_MODULE, PROJECT_DESCRIPTION, META_INFO
|
| 46 |
+
if project_name is not None:
|
| 47 |
+
PROJECT_NAME = project_name
|
| 48 |
+
if year is not None:
|
| 49 |
+
AIO_YEAR = year
|
| 50 |
+
if module is not None:
|
| 51 |
+
AIO_MODULE = module
|
| 52 |
+
if description is not None:
|
| 53 |
+
PROJECT_DESCRIPTION = description
|
| 54 |
+
if meta_items is not None:
|
| 55 |
+
META_INFO = meta_items
|
| 56 |
+
|
| 57 |
+
def configure(project_name: str = None, year: str = None, module: str = None, description: str = None,
|
| 58 |
+
colors: dict = None, font_family: str = None, meta_items: list = None):
|
| 59 |
+
"""One-call configuration for meta, theme, and font."""
|
| 60 |
+
if colors:
|
| 61 |
+
set_colors(
|
| 62 |
+
primary=colors.get("primary"),
|
| 63 |
+
accent=colors.get("accent"),
|
| 64 |
+
bg1=colors.get("bg1"),
|
| 65 |
+
bg2=colors.get("bg2"),
|
| 66 |
+
bg3=colors.get("bg3"),
|
| 67 |
+
)
|
| 68 |
+
if font_family:
|
| 69 |
+
set_font(font_family)
|
| 70 |
+
set_meta(project_name, year, module, description, meta_items)
|
| 71 |
|
| 72 |
|
| 73 |
def image_to_base64(image_path: str):
|
|
|
|
| 90 |
gr.HTML(f"""
|
| 91 |
<div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
|
| 92 |
<div>
|
| 93 |
+
<h1 style="margin-bottom:0; color: {PRIMARY_COLOR}; font-size: 2.5em; font-weight: bold;"> {PROJECT_NAME} </h1>
|
| 94 |
<h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
|
| 95 |
</div>
|
| 96 |
</div>
|
|
|
|
| 115 |
"""
|
| 116 |
return gr.HTML(footer_html)
|
| 117 |
|
| 118 |
+
def _build_custom_css() -> str:
|
| 119 |
+
return f"""
|
| 120 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
|
| 121 |
|
| 122 |
+
.gradio-container {{
|
| 123 |
min-height: 100vh !important;
|
| 124 |
width: 100vw !important;
|
| 125 |
margin: 0 !important;
|
| 126 |
padding: 0px !important;
|
| 127 |
+
background: linear-gradient(135deg, {BG1} 0%, {BG2} 50%, {BG3} 100%);
|
| 128 |
background-size: 600% 600%;
|
| 129 |
animation: gradientBG 7s ease infinite;
|
| 130 |
+
}}
|
| 131 |
+
|
| 132 |
+
/* Global font setup */
|
| 133 |
+
body, .gradio-container, .gr-block, .gr-markdown, .gr-button, .gr-input,
|
| 134 |
+
.gr-dropdown, .gr-number, .gr-plot, .gr-dataframe, .gr-accordion, .gr-form,
|
| 135 |
+
.gr-textbox, .gr-html, table, th, td, label, h1, h2, h3, h4, h5, h6, p, span, div {{
|
| 136 |
+
font-family: {FONT_FAMILY} !important;
|
| 137 |
+
}}
|
| 138 |
|
| 139 |
+
@keyframes gradientBG {{
|
| 140 |
+
0% {{background-position: 0% 50%;}}
|
| 141 |
+
50% {{background-position: 100% 50%;}}
|
| 142 |
+
100% {{background-position: 0% 50%;}}
|
| 143 |
+
}}
|
| 144 |
|
| 145 |
/* Minimize spacing and padding */
|
| 146 |
+
.content-wrap {{
|
| 147 |
padding: 2px !important;
|
| 148 |
margin: 0 !important;
|
| 149 |
+
}}
|
| 150 |
|
| 151 |
/* Reduce component spacing */
|
| 152 |
+
.gr-row {{
|
| 153 |
gap: 5px !important;
|
| 154 |
margin: 2px 0 !important;
|
| 155 |
+
}}
|
| 156 |
|
| 157 |
+
.gr-column {{
|
| 158 |
gap: 4px !important;
|
| 159 |
padding: 4px !important;
|
| 160 |
+
}}
|
| 161 |
|
| 162 |
/* Accordion optimization */
|
| 163 |
+
.gr-accordion {{
|
| 164 |
margin: 4px 0 !important;
|
| 165 |
+
}}
|
| 166 |
|
| 167 |
+
.gr-accordion .gr-accordion-content {{
|
| 168 |
padding: 2px !important;
|
| 169 |
+
}}
|
| 170 |
|
| 171 |
/* Form elements spacing */
|
| 172 |
+
.gr-form {{
|
| 173 |
gap: 2px !important;
|
| 174 |
+
}}
|
| 175 |
|
| 176 |
/* Button styling */
|
| 177 |
+
.gr-button {{
|
| 178 |
margin: 2px 0 !important;
|
| 179 |
+
}}
|
| 180 |
|
| 181 |
/* DataFrame optimization */
|
| 182 |
+
.gr-dataframe {{
|
| 183 |
margin: 4px 0 !important;
|
| 184 |
+
}}
|
| 185 |
|
| 186 |
/* Remove horizontal scroll from data preview */
|
| 187 |
+
.gr-dataframe .wrap {{
|
| 188 |
overflow-x: auto !important;
|
| 189 |
max-width: 100% !important;
|
| 190 |
+
}}
|
| 191 |
|
| 192 |
/* Plot optimization */
|
| 193 |
+
.gr-plot {{
|
| 194 |
margin: 4px 0 !important;
|
| 195 |
+
}}
|
| 196 |
|
| 197 |
/* Reduce markdown margins */
|
| 198 |
+
.gr-markdown {{
|
| 199 |
margin: 2px 0 !important;
|
| 200 |
+
}}
|
| 201 |
|
| 202 |
/* Footer positioning */
|
| 203 |
+
.sticky-footer {{
|
| 204 |
position: fixed;
|
| 205 |
bottom: 0px;
|
| 206 |
left: 0;
|
| 207 |
width: 100%;
|
| 208 |
+
background: {BG1};
|
| 209 |
padding: 6px !important;
|
| 210 |
box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
|
| 211 |
z-index: 1000;
|
| 212 |
+
}}
|
| 213 |
"""
|
| 214 |
+
|
| 215 |
+
# Initialize CSS using defaults
|
| 216 |
+
custom_css = _build_custom_css()
|
| 217 |
+
|
| 218 |
+
def render_info_card(description: str = None, meta_items: list = None, icon: str = "🧠", title: str = "About this demo") -> str:
|
| 219 |
+
desc = description if description is not None else PROJECT_DESCRIPTION
|
| 220 |
+
items = meta_items if meta_items is not None else META_INFO
|
| 221 |
+
meta_html = " · ".join([f"<span><strong>{k}</strong>: {v}</span>" for k, v in items]) if items else ""
|
| 222 |
+
return f"""
|
| 223 |
+
<div style="margin: 8px 0 8px 0;">
|
| 224 |
+
<div style="background:#F5F9FF;border-left:6px solid {PRIMARY_COLOR};padding:14px 16px;border-radius:10px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
|
| 225 |
+
<div style="display:flex;gap:14px;align-items:flex-start;">
|
| 226 |
+
<div style="font-size:22px;">{icon}</div>
|
| 227 |
+
<div>
|
| 228 |
+
<div style="font-weight:700;color:{PRIMARY_COLOR};margin-bottom:4px;">{title}</div>
|
| 229 |
+
<div style="color:#000;font-size:14px;line-height:1.5;">{desc}</div>
|
| 230 |
+
<div style="margin-top:8px;color:#000;font-size:13px;">{meta_html}</div>
|
| 231 |
+
</div>
|
| 232 |
+
</div>
|
| 233 |
+
</div>
|
| 234 |
+
</div>
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
def render_disclaimer(text: str, icon: str = "⚠️", title: str = "Educational Use Only") -> str:
|
| 238 |
+
return f"""
|
| 239 |
+
<div style=\"margin: 8px 0 6px 0;\">
|
| 240 |
+
<div style=\"background:#FFF4F4;border-left:6px solid {ACCENT_COLOR};padding:12px 16px;border-radius:8px;box-shadow:0 1px 3px rgba(0,0,0,0.06);\">
|
| 241 |
+
<div style=\"display:flex;gap:10px;align-items:flex-start;color:#000;\">
|
| 242 |
+
<span style=\"font-size:20px\">{icon}</span>
|
| 243 |
+
<div>
|
| 244 |
+
<div style=\"font-weight:700; margin-bottom:4px;\">{title}</div>
|
| 245 |
+
<div style=\"font-size:14px; line-height:1.4;\">{text}</div>
|
| 246 |
+
</div>
|
| 247 |
+
</div>
|
| 248 |
+
</div>
|
| 249 |
+
</div>
|
| 250 |
+
"""
|