Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -54,7 +54,6 @@ def extract_beats(record, annotation, window_size=257):
|
|
| 54 |
# Function to detect the last Conv1D layer in the model
|
| 55 |
def get_last_conv_layer_name(model):
|
| 56 |
last_conv_layer = None
|
| 57 |
-
# Loop in reverse order over layers to find a Conv1D layer
|
| 58 |
for layer in reversed(model.layers):
|
| 59 |
if isinstance(layer, tf.keras.layers.Conv1D):
|
| 60 |
last_conv_layer = layer.name
|
|
@@ -65,13 +64,6 @@ def get_last_conv_layer_name(model):
|
|
| 65 |
|
| 66 |
# Function to generate Grad-CAM heatmap for a given beat and class index
|
| 67 |
def generate_grad_cam(model, sample, layer_name):
|
| 68 |
-
"""
|
| 69 |
-
model : your loaded Keras model
|
| 70 |
-
sample : a 4D tensor of shape (1, window_size, 1)
|
| 71 |
-
layer_name : name of the Conv1D layer to use for Grad‑CAM
|
| 72 |
-
returns : 1D numpy heatmap of length window_size
|
| 73 |
-
"""
|
| 74 |
-
# Build a model that returns both the conv outputs and the predictions
|
| 75 |
grad_model = tf.keras.models.Model(
|
| 76 |
inputs=model.inputs,
|
| 77 |
outputs=[model.get_layer(layer_name).output, model.output]
|
|
@@ -79,45 +71,40 @@ def generate_grad_cam(model, sample, layer_name):
|
|
| 79 |
|
| 80 |
with tf.GradientTape() as tape:
|
| 81 |
conv_outputs, predictions = grad_model(sample)
|
| 82 |
-
# pick the top predicted class
|
| 83 |
class_idx = tf.argmax(predictions[0])
|
| 84 |
loss = predictions[:, class_idx]
|
| 85 |
|
| 86 |
-
# gradient of the loss wrt conv outputs
|
| 87 |
grads = tape.gradient(loss, conv_outputs)
|
| 88 |
-
|
| 89 |
-
# global average pool the gradients to get the importance of each channel
|
| 90 |
-
pooled_grads = tf.reduce_mean(grads, axis=(0, 1)) # shape = (channels,)
|
| 91 |
-
|
| 92 |
-
# remove batch dim from conv_outputs -> (time, channels)
|
| 93 |
conv_outputs = tf.squeeze(conv_outputs, axis=0)
|
| 94 |
-
|
| 95 |
-
# weight the conv outputs by the pooled gradients
|
| 96 |
-
cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1) # shape = (time,)
|
| 97 |
raw = cam.numpy()
|
| 98 |
print("raw min/max:", raw.min(), raw.max())
|
| 99 |
|
| 100 |
-
cam = tf.abs(cam)
|
| 101 |
-
cam = cam / (tf.reduce_max(cam) + 1e-8)
|
| 102 |
|
| 103 |
return cam.numpy()
|
| 104 |
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
# Streamlit App Layout
|
| 108 |
st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
|
| 109 |
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
|
| 110 |
-
record_loaded = False
|
| 111 |
-
record = None
|
| 112 |
-
annotation = None
|
| 113 |
|
| 114 |
# Load Record 108 Button
|
| 115 |
if st.button("Load Record 108"):
|
| 116 |
try:
|
| 117 |
base_name = "108"
|
| 118 |
-
record = wfdb.rdrecord(base_name)
|
| 119 |
-
annotation = wfdb.rdann(base_name, 'atr')
|
| 120 |
-
record_loaded = True
|
| 121 |
except Exception as e:
|
| 122 |
st.error(f"Error loading Record 108: {str(e)}")
|
| 123 |
|
|
@@ -128,7 +115,7 @@ uploaded_files = st.file_uploader(
|
|
| 128 |
accept_multiple_files=True
|
| 129 |
)
|
| 130 |
|
| 131 |
-
if uploaded_files and not record_loaded:
|
| 132 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 133 |
for f in uploaded_files:
|
| 134 |
file_path = os.path.join(tmpdir, f.name)
|
|
@@ -137,15 +124,15 @@ if uploaded_files and not record_loaded:
|
|
| 137 |
base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
|
| 138 |
common_base = list(base_names)[0]
|
| 139 |
try:
|
| 140 |
-
record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
|
| 141 |
-
annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
|
| 142 |
-
record_loaded = True
|
| 143 |
except Exception as e:
|
| 144 |
st.error(f"Error reading uploaded files: {str(e)}")
|
| 145 |
|
| 146 |
# Process the record if loaded
|
| 147 |
-
if record_loaded and record is not None and annotation is not None:
|
| 148 |
-
beats = extract_beats(record, annotation)
|
| 149 |
if len(beats) == 0:
|
| 150 |
st.error("No valid beats found in the record")
|
| 151 |
st.stop()
|
|
@@ -162,7 +149,6 @@ if record_loaded and record is not None and annotation is not None:
|
|
| 162 |
})
|
| 163 |
st.dataframe(results)
|
| 164 |
|
| 165 |
-
# Class Distribution Section
|
| 166 |
st.subheader("Class Distribution")
|
| 167 |
class_indices = list(class_map.keys())
|
| 168 |
class_names = [class_map[i] for i in class_indices]
|
|
@@ -177,62 +163,40 @@ if record_loaded and record is not None and annotation is not None:
|
|
| 177 |
with col2:
|
| 178 |
st.bar_chart(distribution_df.set_index('Class'))
|
| 179 |
|
| 180 |
-
# Display a Sample ECG Beat
|
| 181 |
st.subheader("Sample ECG Beat")
|
| 182 |
fig, ax = plt.subplots()
|
| 183 |
ax.plot(beats[0].flatten(), label="ECG Beat")
|
| 184 |
ax.legend()
|
| 185 |
st.pyplot(fig)
|
| 186 |
|
| 187 |
-
# ---------------- Grad-CAM Visualization Section ----------------
|
| 188 |
st.subheader("Class Comparison with Grad-CAM")
|
| 189 |
st.write("Compare model explanations between classes present in this record")
|
| 190 |
|
| 191 |
-
# Automatically detect the last convolutional layer name
|
| 192 |
conv_layer_name = get_last_conv_layer_name(model)
|
| 193 |
if conv_layer_name is not None:
|
| 194 |
st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
|
| 195 |
|
| 196 |
-
# Get classes actually present in the data
|
| 197 |
present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
|
| 198 |
if not present_classes:
|
| 199 |
st.warning("No classes with detected beats to compare")
|
| 200 |
st.stop()
|
| 201 |
|
| 202 |
-
# Class selection dropdowns
|
| 203 |
col1, col2, col3 = st.columns([1, 1, 1])
|
| 204 |
with col1:
|
| 205 |
-
left_class = st.selectbox(
|
| 206 |
-
"Left Class:",
|
| 207 |
-
options=present_classes,
|
| 208 |
-
index=0
|
| 209 |
-
)
|
| 210 |
with col2:
|
| 211 |
-
# Default to second class if available, else first
|
| 212 |
right_index = 1 if len(present_classes) > 1 else 0
|
| 213 |
-
right_class = st.selectbox(
|
| 214 |
-
"Right Class:",
|
| 215 |
-
options=present_classes,
|
| 216 |
-
index=right_index
|
| 217 |
-
)
|
| 218 |
with col3:
|
| 219 |
-
num_beats = st.number_input(
|
| 220 |
-
"Beats per class:",
|
| 221 |
-
min_value=1,
|
| 222 |
-
max_value=10,
|
| 223 |
-
value=3
|
| 224 |
-
)
|
| 225 |
|
| 226 |
-
# Get class indices from names
|
| 227 |
class_name_to_idx = {v: k for k, v in class_map.items()}
|
| 228 |
left_class_idx = class_name_to_idx[left_class]
|
| 229 |
right_class_idx = class_name_to_idx[right_class]
|
| 230 |
left_indices = np.where(predicted_classes == left_class_idx)[0]
|
| 231 |
right_indices = np.where(predicted_classes == right_class_idx)[0]
|
| 232 |
|
| 233 |
-
# Create comparison columns
|
| 234 |
left_col, right_col = st.columns(2)
|
| 235 |
-
|
| 236 |
def display_class_beats(col, class_name, beat_indices, num_beats):
|
| 237 |
with col:
|
| 238 |
st.subheader(class_name)
|
|
@@ -241,43 +205,26 @@ if record_loaded and record is not None and annotation is not None:
|
|
| 241 |
return
|
| 242 |
|
| 243 |
for beat_idx in beat_indices[:num_beats]:
|
| 244 |
-
beat = beats[beat_idx].flatten()
|
| 245 |
sample = beat.reshape(1, -1, 1).astype(np.float32)
|
| 246 |
-
|
| 247 |
-
# generate the 1D heatmap
|
| 248 |
heatmap = generate_grad_cam(model, sample, conv_layer_name)
|
| 249 |
-
|
| 250 |
-
# set up figure
|
| 251 |
fig, ax = plt.subplots(figsize=(8, 2))
|
| 252 |
y_min, y_max = beat.min(), beat.max()
|
| 253 |
-
|
| 254 |
-
# Always draw the heatmap background for all beats
|
| 255 |
ax.imshow(
|
| 256 |
-
np.expand_dims(heatmap, axis=0),
|
| 257 |
aspect='auto',
|
| 258 |
cmap='jet',
|
| 259 |
alpha=0.5,
|
| 260 |
extent=[0, len(beat), y_min, y_max]
|
| 261 |
)
|
| 262 |
-
|
| 263 |
-
# overlay the ECG trace
|
| 264 |
ax.plot(beat, linewidth=2, color='blue')
|
| 265 |
-
|
| 266 |
-
# styling
|
| 267 |
-
# Do NOT set a facecolor here - it will block the heatmap
|
| 268 |
-
# ax.set_facecolor('#e0e0f0') # This line is commented out
|
| 269 |
-
ax.axis('off') # clean look
|
| 270 |
ax.set_title(f"Beat {beat_idx}")
|
| 271 |
ax.set_xlim(0, len(beat))
|
| 272 |
ax.set_ylim(y_min, y_max)
|
| 273 |
-
|
| 274 |
st.pyplot(fig)
|
| 275 |
-
# Display left class beats
|
| 276 |
display_class_beats(left_col, left_class, left_indices, num_beats)
|
| 277 |
-
|
| 278 |
-
# Display right class beats
|
| 279 |
display_class_beats(right_col, right_class, right_indices, num_beats)
|
| 280 |
|
| 281 |
-
# Add comparison note if same class selected
|
| 282 |
if left_class == right_class:
|
| 283 |
st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")
|
|
|
|
| 54 |
# Function to detect the last Conv1D layer in the model
|
| 55 |
def get_last_conv_layer_name(model):
|
| 56 |
last_conv_layer = None
|
|
|
|
| 57 |
for layer in reversed(model.layers):
|
| 58 |
if isinstance(layer, tf.keras.layers.Conv1D):
|
| 59 |
last_conv_layer = layer.name
|
|
|
|
| 64 |
|
| 65 |
# Function to generate Grad-CAM heatmap for a given beat and class index
|
| 66 |
def generate_grad_cam(model, sample, layer_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
grad_model = tf.keras.models.Model(
|
| 68 |
inputs=model.inputs,
|
| 69 |
outputs=[model.get_layer(layer_name).output, model.output]
|
|
|
|
| 71 |
|
| 72 |
with tf.GradientTape() as tape:
|
| 73 |
conv_outputs, predictions = grad_model(sample)
|
|
|
|
| 74 |
class_idx = tf.argmax(predictions[0])
|
| 75 |
loss = predictions[:, class_idx]
|
| 76 |
|
|
|
|
| 77 |
grads = tape.gradient(loss, conv_outputs)
|
| 78 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
conv_outputs = tf.squeeze(conv_outputs, axis=0)
|
| 80 |
+
cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)
|
|
|
|
|
|
|
| 81 |
raw = cam.numpy()
|
| 82 |
print("raw min/max:", raw.min(), raw.max())
|
| 83 |
|
| 84 |
+
cam = tf.abs(cam)
|
| 85 |
+
cam = cam / (tf.reduce_max(cam) + 1e-8)
|
| 86 |
|
| 87 |
return cam.numpy()
|
| 88 |
|
| 89 |
+
# Initialize session state variables if not already set
|
| 90 |
+
if 'record_loaded' not in st.session_state:
|
| 91 |
+
st.session_state.record_loaded = False
|
| 92 |
+
if 'record' not in st.session_state:
|
| 93 |
+
st.session_state.record = None
|
| 94 |
+
if 'annotation' not in st.session_state:
|
| 95 |
+
st.session_state.annotation = None
|
| 96 |
|
| 97 |
# Streamlit App Layout
|
| 98 |
st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
|
| 99 |
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Load Record 108 Button
|
| 102 |
if st.button("Load Record 108"):
|
| 103 |
try:
|
| 104 |
base_name = "108"
|
| 105 |
+
st.session_state.record = wfdb.rdrecord(base_name)
|
| 106 |
+
st.session_state.annotation = wfdb.rdann(base_name, 'atr')
|
| 107 |
+
st.session_state.record_loaded = True
|
| 108 |
except Exception as e:
|
| 109 |
st.error(f"Error loading Record 108: {str(e)}")
|
| 110 |
|
|
|
|
| 115 |
accept_multiple_files=True
|
| 116 |
)
|
| 117 |
|
| 118 |
+
if uploaded_files and not st.session_state.record_loaded:
|
| 119 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 120 |
for f in uploaded_files:
|
| 121 |
file_path = os.path.join(tmpdir, f.name)
|
|
|
|
| 124 |
base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
|
| 125 |
common_base = list(base_names)[0]
|
| 126 |
try:
|
| 127 |
+
st.session_state.record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
|
| 128 |
+
st.session_state.annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
|
| 129 |
+
st.session_state.record_loaded = True
|
| 130 |
except Exception as e:
|
| 131 |
st.error(f"Error reading uploaded files: {str(e)}")
|
| 132 |
|
| 133 |
# Process the record if loaded
|
| 134 |
+
if st.session_state.record_loaded and st.session_state.record is not None and st.session_state.annotation is not None:
|
| 135 |
+
beats = extract_beats(st.session_state.record, st.session_state.annotation)
|
| 136 |
if len(beats) == 0:
|
| 137 |
st.error("No valid beats found in the record")
|
| 138 |
st.stop()
|
|
|
|
| 149 |
})
|
| 150 |
st.dataframe(results)
|
| 151 |
|
|
|
|
| 152 |
st.subheader("Class Distribution")
|
| 153 |
class_indices = list(class_map.keys())
|
| 154 |
class_names = [class_map[i] for i in class_indices]
|
|
|
|
| 163 |
with col2:
|
| 164 |
st.bar_chart(distribution_df.set_index('Class'))
|
| 165 |
|
|
|
|
| 166 |
st.subheader("Sample ECG Beat")
|
| 167 |
fig, ax = plt.subplots()
|
| 168 |
ax.plot(beats[0].flatten(), label="ECG Beat")
|
| 169 |
ax.legend()
|
| 170 |
st.pyplot(fig)
|
| 171 |
|
|
|
|
| 172 |
st.subheader("Class Comparison with Grad-CAM")
|
| 173 |
st.write("Compare model explanations between classes present in this record")
|
| 174 |
|
|
|
|
| 175 |
conv_layer_name = get_last_conv_layer_name(model)
|
| 176 |
if conv_layer_name is not None:
|
| 177 |
st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
|
| 178 |
|
|
|
|
| 179 |
present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
|
| 180 |
if not present_classes:
|
| 181 |
st.warning("No classes with detected beats to compare")
|
| 182 |
st.stop()
|
| 183 |
|
|
|
|
| 184 |
col1, col2, col3 = st.columns([1, 1, 1])
|
| 185 |
with col1:
|
| 186 |
+
left_class = st.selectbox("Left Class:", options=present_classes, index=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
with col2:
|
|
|
|
| 188 |
right_index = 1 if len(present_classes) > 1 else 0
|
| 189 |
+
right_class = st.selectbox("Right Class:", options=present_classes, index=right_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
with col3:
|
| 191 |
+
num_beats = st.number_input("Beats per class:", min_value=1, max_value=10, value=3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
|
|
|
| 193 |
class_name_to_idx = {v: k for k, v in class_map.items()}
|
| 194 |
left_class_idx = class_name_to_idx[left_class]
|
| 195 |
right_class_idx = class_name_to_idx[right_class]
|
| 196 |
left_indices = np.where(predicted_classes == left_class_idx)[0]
|
| 197 |
right_indices = np.where(predicted_classes == right_class_idx)[0]
|
| 198 |
|
|
|
|
| 199 |
left_col, right_col = st.columns(2)
|
|
|
|
| 200 |
def display_class_beats(col, class_name, beat_indices, num_beats):
|
| 201 |
with col:
|
| 202 |
st.subheader(class_name)
|
|
|
|
| 205 |
return
|
| 206 |
|
| 207 |
for beat_idx in beat_indices[:num_beats]:
|
| 208 |
+
beat = beats[beat_idx].flatten()
|
| 209 |
sample = beat.reshape(1, -1, 1).astype(np.float32)
|
|
|
|
|
|
|
| 210 |
heatmap = generate_grad_cam(model, sample, conv_layer_name)
|
|
|
|
|
|
|
| 211 |
fig, ax = plt.subplots(figsize=(8, 2))
|
| 212 |
y_min, y_max = beat.min(), beat.max()
|
|
|
|
|
|
|
| 213 |
ax.imshow(
|
| 214 |
+
np.expand_dims(heatmap, axis=0),
|
| 215 |
aspect='auto',
|
| 216 |
cmap='jet',
|
| 217 |
alpha=0.5,
|
| 218 |
extent=[0, len(beat), y_min, y_max]
|
| 219 |
)
|
|
|
|
|
|
|
| 220 |
ax.plot(beat, linewidth=2, color='blue')
|
| 221 |
+
ax.axis('off')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
ax.set_title(f"Beat {beat_idx}")
|
| 223 |
ax.set_xlim(0, len(beat))
|
| 224 |
ax.set_ylim(y_min, y_max)
|
|
|
|
| 225 |
st.pyplot(fig)
|
|
|
|
| 226 |
display_class_beats(left_col, left_class, left_indices, num_beats)
|
|
|
|
|
|
|
| 227 |
display_class_beats(right_col, right_class, right_indices, num_beats)
|
| 228 |
|
|
|
|
| 229 |
if left_class == right_class:
|
| 230 |
st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")
|