lukiod commited on
Commit
6de8626
·
verified ·
1 Parent(s): 0a632d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -79
app.py CHANGED
@@ -54,7 +54,6 @@ def extract_beats(record, annotation, window_size=257):
54
  # Function to detect the last Conv1D layer in the model
55
  def get_last_conv_layer_name(model):
56
  last_conv_layer = None
57
- # Loop in reverse order over layers to find a Conv1D layer
58
  for layer in reversed(model.layers):
59
  if isinstance(layer, tf.keras.layers.Conv1D):
60
  last_conv_layer = layer.name
@@ -65,13 +64,6 @@ def get_last_conv_layer_name(model):
65
 
66
  # Function to generate Grad-CAM heatmap for a given beat and class index
67
  def generate_grad_cam(model, sample, layer_name):
68
- """
69
- model : your loaded Keras model
70
- sample : a 4D tensor of shape (1, window_size, 1)
71
- layer_name : name of the Conv1D layer to use for Grad‑CAM
72
- returns : 1D numpy heatmap of length window_size
73
- """
74
- # Build a model that returns both the conv outputs and the predictions
75
  grad_model = tf.keras.models.Model(
76
  inputs=model.inputs,
77
  outputs=[model.get_layer(layer_name).output, model.output]
@@ -79,45 +71,40 @@ def generate_grad_cam(model, sample, layer_name):
79
 
80
  with tf.GradientTape() as tape:
81
  conv_outputs, predictions = grad_model(sample)
82
- # pick the top predicted class
83
  class_idx = tf.argmax(predictions[0])
84
  loss = predictions[:, class_idx]
85
 
86
- # gradient of the loss wrt conv outputs
87
  grads = tape.gradient(loss, conv_outputs)
88
-
89
- # global average pool the gradients to get the importance of each channel
90
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1)) # shape = (channels,)
91
-
92
- # remove batch dim from conv_outputs -> (time, channels)
93
  conv_outputs = tf.squeeze(conv_outputs, axis=0)
94
-
95
- # weight the conv outputs by the pooled gradients
96
- cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1) # shape = (time,)
97
  raw = cam.numpy()
98
  print("raw min/max:", raw.min(), raw.max())
99
 
100
- cam = tf.abs(cam) # ReLU
101
- cam = cam / (tf.reduce_max(cam) + 1e-8) # normalize
102
 
103
  return cam.numpy()
104
 
105
-
 
 
 
 
 
 
106
 
107
  # Streamlit App Layout
108
  st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
109
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
110
- record_loaded = False
111
- record = None
112
- annotation = None
113
 
114
  # Load Record 108 Button
115
  if st.button("Load Record 108"):
116
  try:
117
  base_name = "108"
118
- record = wfdb.rdrecord(base_name)
119
- annotation = wfdb.rdann(base_name, 'atr')
120
- record_loaded = True
121
  except Exception as e:
122
  st.error(f"Error loading Record 108: {str(e)}")
123
 
@@ -128,7 +115,7 @@ uploaded_files = st.file_uploader(
128
  accept_multiple_files=True
129
  )
130
 
131
- if uploaded_files and not record_loaded:
132
  with tempfile.TemporaryDirectory() as tmpdir:
133
  for f in uploaded_files:
134
  file_path = os.path.join(tmpdir, f.name)
@@ -137,15 +124,15 @@ if uploaded_files and not record_loaded:
137
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
138
  common_base = list(base_names)[0]
139
  try:
140
- record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
141
- annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
142
- record_loaded = True
143
  except Exception as e:
144
  st.error(f"Error reading uploaded files: {str(e)}")
145
 
146
  # Process the record if loaded
147
- if record_loaded and record is not None and annotation is not None:
148
- beats = extract_beats(record, annotation)
149
  if len(beats) == 0:
150
  st.error("No valid beats found in the record")
151
  st.stop()
@@ -162,7 +149,6 @@ if record_loaded and record is not None and annotation is not None:
162
  })
163
  st.dataframe(results)
164
 
165
- # Class Distribution Section
166
  st.subheader("Class Distribution")
167
  class_indices = list(class_map.keys())
168
  class_names = [class_map[i] for i in class_indices]
@@ -177,62 +163,40 @@ if record_loaded and record is not None and annotation is not None:
177
  with col2:
178
  st.bar_chart(distribution_df.set_index('Class'))
179
 
180
- # Display a Sample ECG Beat
181
  st.subheader("Sample ECG Beat")
182
  fig, ax = plt.subplots()
183
  ax.plot(beats[0].flatten(), label="ECG Beat")
184
  ax.legend()
185
  st.pyplot(fig)
186
 
187
- # ---------------- Grad-CAM Visualization Section ----------------
188
  st.subheader("Class Comparison with Grad-CAM")
189
  st.write("Compare model explanations between classes present in this record")
190
 
191
- # Automatically detect the last convolutional layer name
192
  conv_layer_name = get_last_conv_layer_name(model)
193
  if conv_layer_name is not None:
194
  st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
195
 
196
- # Get classes actually present in the data
197
  present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
198
  if not present_classes:
199
  st.warning("No classes with detected beats to compare")
200
  st.stop()
201
 
202
- # Class selection dropdowns
203
  col1, col2, col3 = st.columns([1, 1, 1])
204
  with col1:
205
- left_class = st.selectbox(
206
- "Left Class:",
207
- options=present_classes,
208
- index=0
209
- )
210
  with col2:
211
- # Default to second class if available, else first
212
  right_index = 1 if len(present_classes) > 1 else 0
213
- right_class = st.selectbox(
214
- "Right Class:",
215
- options=present_classes,
216
- index=right_index
217
- )
218
  with col3:
219
- num_beats = st.number_input(
220
- "Beats per class:",
221
- min_value=1,
222
- max_value=10,
223
- value=3
224
- )
225
 
226
- # Get class indices from names
227
  class_name_to_idx = {v: k for k, v in class_map.items()}
228
  left_class_idx = class_name_to_idx[left_class]
229
  right_class_idx = class_name_to_idx[right_class]
230
  left_indices = np.where(predicted_classes == left_class_idx)[0]
231
  right_indices = np.where(predicted_classes == right_class_idx)[0]
232
 
233
- # Create comparison columns
234
  left_col, right_col = st.columns(2)
235
-
236
  def display_class_beats(col, class_name, beat_indices, num_beats):
237
  with col:
238
  st.subheader(class_name)
@@ -241,43 +205,26 @@ if record_loaded and record is not None and annotation is not None:
241
  return
242
 
243
  for beat_idx in beat_indices[:num_beats]:
244
- beat = beats[beat_idx].flatten() # shape (window_size,)
245
  sample = beat.reshape(1, -1, 1).astype(np.float32)
246
-
247
- # generate the 1D heatmap
248
  heatmap = generate_grad_cam(model, sample, conv_layer_name)
249
-
250
- # set up figure
251
  fig, ax = plt.subplots(figsize=(8, 2))
252
  y_min, y_max = beat.min(), beat.max()
253
-
254
- # Always draw the heatmap background for all beats
255
  ax.imshow(
256
- np.expand_dims(heatmap, axis=0), # shape (1, window_size)
257
  aspect='auto',
258
  cmap='jet',
259
  alpha=0.5,
260
  extent=[0, len(beat), y_min, y_max]
261
  )
262
-
263
- # overlay the ECG trace
264
  ax.plot(beat, linewidth=2, color='blue')
265
-
266
- # styling
267
- # Do NOT set a facecolor here - it will block the heatmap
268
- # ax.set_facecolor('#e0e0f0') # This line is commented out
269
- ax.axis('off') # clean look
270
  ax.set_title(f"Beat {beat_idx}")
271
  ax.set_xlim(0, len(beat))
272
  ax.set_ylim(y_min, y_max)
273
-
274
  st.pyplot(fig)
275
- # Display left class beats
276
  display_class_beats(left_col, left_class, left_indices, num_beats)
277
-
278
- # Display right class beats
279
  display_class_beats(right_col, right_class, right_indices, num_beats)
280
 
281
- # Add comparison note if same class selected
282
  if left_class == right_class:
283
  st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")
 
54
  # Function to detect the last Conv1D layer in the model
55
  def get_last_conv_layer_name(model):
56
  last_conv_layer = None
 
57
  for layer in reversed(model.layers):
58
  if isinstance(layer, tf.keras.layers.Conv1D):
59
  last_conv_layer = layer.name
 
64
 
65
  # Function to generate Grad-CAM heatmap for a given beat and class index
66
  def generate_grad_cam(model, sample, layer_name):
 
 
 
 
 
 
 
67
  grad_model = tf.keras.models.Model(
68
  inputs=model.inputs,
69
  outputs=[model.get_layer(layer_name).output, model.output]
 
71
 
72
  with tf.GradientTape() as tape:
73
  conv_outputs, predictions = grad_model(sample)
 
74
  class_idx = tf.argmax(predictions[0])
75
  loss = predictions[:, class_idx]
76
 
 
77
  grads = tape.gradient(loss, conv_outputs)
78
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
 
 
 
 
79
  conv_outputs = tf.squeeze(conv_outputs, axis=0)
80
+ cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)
 
 
81
  raw = cam.numpy()
82
  print("raw min/max:", raw.min(), raw.max())
83
 
84
+ cam = tf.abs(cam)
85
+ cam = cam / (tf.reduce_max(cam) + 1e-8)
86
 
87
  return cam.numpy()
88
 
89
+ # Initialize session state variables if not already set
90
+ if 'record_loaded' not in st.session_state:
91
+ st.session_state.record_loaded = False
92
+ if 'record' not in st.session_state:
93
+ st.session_state.record = None
94
+ if 'annotation' not in st.session_state:
95
+ st.session_state.annotation = None
96
 
97
  # Streamlit App Layout
98
  st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
99
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
 
 
 
100
 
101
  # Load Record 108 Button
102
  if st.button("Load Record 108"):
103
  try:
104
  base_name = "108"
105
+ st.session_state.record = wfdb.rdrecord(base_name)
106
+ st.session_state.annotation = wfdb.rdann(base_name, 'atr')
107
+ st.session_state.record_loaded = True
108
  except Exception as e:
109
  st.error(f"Error loading Record 108: {str(e)}")
110
 
 
115
  accept_multiple_files=True
116
  )
117
 
118
+ if uploaded_files and not st.session_state.record_loaded:
119
  with tempfile.TemporaryDirectory() as tmpdir:
120
  for f in uploaded_files:
121
  file_path = os.path.join(tmpdir, f.name)
 
124
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
125
  common_base = list(base_names)[0]
126
  try:
127
+ st.session_state.record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
128
+ st.session_state.annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
129
+ st.session_state.record_loaded = True
130
  except Exception as e:
131
  st.error(f"Error reading uploaded files: {str(e)}")
132
 
133
  # Process the record if loaded
134
+ if st.session_state.record_loaded and st.session_state.record is not None and st.session_state.annotation is not None:
135
+ beats = extract_beats(st.session_state.record, st.session_state.annotation)
136
  if len(beats) == 0:
137
  st.error("No valid beats found in the record")
138
  st.stop()
 
149
  })
150
  st.dataframe(results)
151
 
 
152
  st.subheader("Class Distribution")
153
  class_indices = list(class_map.keys())
154
  class_names = [class_map[i] for i in class_indices]
 
163
  with col2:
164
  st.bar_chart(distribution_df.set_index('Class'))
165
 
 
166
  st.subheader("Sample ECG Beat")
167
  fig, ax = plt.subplots()
168
  ax.plot(beats[0].flatten(), label="ECG Beat")
169
  ax.legend()
170
  st.pyplot(fig)
171
 
 
172
  st.subheader("Class Comparison with Grad-CAM")
173
  st.write("Compare model explanations between classes present in this record")
174
 
 
175
  conv_layer_name = get_last_conv_layer_name(model)
176
  if conv_layer_name is not None:
177
  st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
178
 
 
179
  present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
180
  if not present_classes:
181
  st.warning("No classes with detected beats to compare")
182
  st.stop()
183
 
 
184
  col1, col2, col3 = st.columns([1, 1, 1])
185
  with col1:
186
+ left_class = st.selectbox("Left Class:", options=present_classes, index=0)
 
 
 
 
187
  with col2:
 
188
  right_index = 1 if len(present_classes) > 1 else 0
189
+ right_class = st.selectbox("Right Class:", options=present_classes, index=right_index)
 
 
 
 
190
  with col3:
191
+ num_beats = st.number_input("Beats per class:", min_value=1, max_value=10, value=3)
 
 
 
 
 
192
 
 
193
  class_name_to_idx = {v: k for k, v in class_map.items()}
194
  left_class_idx = class_name_to_idx[left_class]
195
  right_class_idx = class_name_to_idx[right_class]
196
  left_indices = np.where(predicted_classes == left_class_idx)[0]
197
  right_indices = np.where(predicted_classes == right_class_idx)[0]
198
 
 
199
  left_col, right_col = st.columns(2)
 
200
  def display_class_beats(col, class_name, beat_indices, num_beats):
201
  with col:
202
  st.subheader(class_name)
 
205
  return
206
 
207
  for beat_idx in beat_indices[:num_beats]:
208
+ beat = beats[beat_idx].flatten()
209
  sample = beat.reshape(1, -1, 1).astype(np.float32)
 
 
210
  heatmap = generate_grad_cam(model, sample, conv_layer_name)
 
 
211
  fig, ax = plt.subplots(figsize=(8, 2))
212
  y_min, y_max = beat.min(), beat.max()
 
 
213
  ax.imshow(
214
+ np.expand_dims(heatmap, axis=0),
215
  aspect='auto',
216
  cmap='jet',
217
  alpha=0.5,
218
  extent=[0, len(beat), y_min, y_max]
219
  )
 
 
220
  ax.plot(beat, linewidth=2, color='blue')
221
+ ax.axis('off')
 
 
 
 
222
  ax.set_title(f"Beat {beat_idx}")
223
  ax.set_xlim(0, len(beat))
224
  ax.set_ylim(y_min, y_max)
 
225
  st.pyplot(fig)
 
226
  display_class_beats(left_col, left_class, left_indices, num_beats)
 
 
227
  display_class_beats(right_col, right_class, right_indices, num_beats)
228
 
 
229
  if left_class == right_class:
230
  st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")