lukiod commited on
Commit
28e701d
·
verified ·
1 Parent(s): fd3c0db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -15
app.py CHANGED
@@ -7,6 +7,7 @@ import tempfile
7
  import os
8
  from scipy.signal import resample
9
  import matplotlib.pyplot as plt
 
10
 
11
  # Custom activation functions
12
  def sin_activation(x):
@@ -37,11 +38,11 @@ class_map = {
37
  4: "Unknown"
38
  }
39
 
 
40
  def extract_beats(record, annotation, window_size=257):
41
  beats = []
42
  r_locs = annotation.sample
43
  signal = record.p_signal[:, 0] # Using first channel
44
-
45
  for r in r_locs:
46
  start = max(0, r - window_size//2)
47
  end = min(len(signal), r + window_size//2 + 1)
@@ -50,7 +51,54 @@ def extract_beats(record, annotation, window_size=257):
50
  beats.append(beat)
51
  return np.array(beats)
52
 
53
- st.title("ECG Arrhythmia Classification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
55
  record_loaded = False
56
  record = None
@@ -79,10 +127,8 @@ if uploaded_files and not record_loaded:
79
  file_path = os.path.join(tmpdir, f.name)
80
  with open(file_path, "wb") as f_out:
81
  f_out.write(f.getbuffer())
82
-
83
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
84
  common_base = list(base_names)[0]
85
-
86
  try:
87
  record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
88
  annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
@@ -90,13 +136,13 @@ if uploaded_files and not record_loaded:
90
  except Exception as e:
91
  st.error(f"Error reading uploaded files: {str(e)}")
92
 
93
- # Run processing if record is loaded
94
  if record_loaded and record is not None and annotation is not None:
95
  beats = extract_beats(record, annotation)
96
  if len(beats) == 0:
97
  st.error("No valid beats found in the record")
98
  st.stop()
99
-
100
  beats = beats.reshape((-1, 257, 1)).astype(np.float32)
101
  predictions = model.predict(beats)
102
  predicted_classes = np.argmax(predictions, axis=1)
@@ -111,27 +157,60 @@ if record_loaded and record is not None and annotation is not None:
111
 
112
  # Class Distribution Section
113
  st.subheader("Class Distribution")
114
-
115
- # Get counts for all classes
116
  class_indices = list(class_map.keys())
117
  class_names = [class_map[i] for i in class_indices]
118
  counts = [np.sum(predicted_classes == i) for i in class_indices]
119
-
120
- # Create distribution dataframe
121
  distribution_df = pd.DataFrame({
122
  "Class": class_names,
123
  "Count": counts
124
  })
125
-
126
- # Display in two columns
127
  col1, col2 = st.columns([1, 2])
128
  with col1:
129
  st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
130
-
131
  with col2:
132
  st.bar_chart(distribution_df.set_index('Class'))
133
 
 
134
  st.subheader("Sample ECG Beat")
135
  fig, ax = plt.subplots()
136
- ax.plot(beats[0].flatten())
137
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  from scipy.signal import resample
9
  import matplotlib.pyplot as plt
10
+ import cv2
11
 
12
  # Custom activation functions
13
  def sin_activation(x):
 
38
  4: "Unknown"
39
  }
40
 
41
+ # Function to extract beats from record
42
  def extract_beats(record, annotation, window_size=257):
43
  beats = []
44
  r_locs = annotation.sample
45
  signal = record.p_signal[:, 0] # Using first channel
 
46
  for r in r_locs:
47
  start = max(0, r - window_size//2)
48
  end = min(len(signal), r + window_size//2 + 1)
 
51
  beats.append(beat)
52
  return np.array(beats)
53
 
54
+ # Function to detect the last Conv1D layer in the model
55
+ def get_last_conv_layer_name(model):
56
+ last_conv_layer = None
57
+ # Loop in reverse order over layers to find a Conv1D layer
58
+ for layer in reversed(model.layers):
59
+ if isinstance(layer, tf.keras.layers.Conv1D):
60
+ last_conv_layer = layer.name
61
+ break
62
+ if last_conv_layer is None:
63
+ st.error("No Conv1D layer found in the model. Grad-CAM requires a convolution layer.")
64
+ return last_conv_layer
65
+
66
+ # Function to generate Grad-CAM heatmap for a given beat and class index
67
+ def make_gradcam_heatmap(beat, model, conv_layer_name, class_index):
68
+ # Create a model that maps the input beat to the activations of the conv layer and the output predictions
69
+ grad_model = tf.keras.models.Model(
70
+ [model.inputs],
71
+ [model.get_layer(conv_layer_name).output, model.output]
72
+ )
73
+ # Record operations for automatic differentiation
74
+ with tf.GradientTape() as tape:
75
+ # Expand dims to add batch axis: shape (1, 257, 1)
76
+ beat_tensor = tf.expand_dims(beat, axis=0)
77
+ conv_outputs, predictions = grad_model(beat_tensor)
78
+ loss = predictions[:, class_index]
79
+ # Compute gradients of the target class wrt feature map
80
+ grads = tape.gradient(loss, conv_outputs)
81
+ # Global average pooling over the time dimension to get weights
82
+ weights = tf.reduce_mean(grads, axis=1)
83
+ # Compute the weighted sum of feature maps along the channel dimension
84
+ cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)
85
+ cam = tf.squeeze(cam) # Remove batch dimension
86
+ # Apply ReLU to the heatmap to keep only positive influences
87
+ heatmap = tf.maximum(cam, 0)
88
+ # Normalize heatmap to the [0, 1] range
89
+ heatmap_max = tf.reduce_max(heatmap)
90
+ if heatmap_max == 0:
91
+ heatmap = tf.zeros_like(heatmap)
92
+ else:
93
+ heatmap /= heatmap_max
94
+ heatmap = heatmap.numpy()
95
+ # Resize heatmap to match the input beat size (if needed)
96
+ # For 1D, we use cv2.resize with the new shape (length, 1) then flatten
97
+ heatmap = cv2.resize(heatmap, (beat.shape[0], 1)).flatten()
98
+ return heatmap
99
+
100
+ # Streamlit App Layout
101
+ st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
102
  st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
103
  record_loaded = False
104
  record = None
 
127
  file_path = os.path.join(tmpdir, f.name)
128
  with open(file_path, "wb") as f_out:
129
  f_out.write(f.getbuffer())
 
130
  base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
131
  common_base = list(base_names)[0]
 
132
  try:
133
  record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
134
  annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
 
136
  except Exception as e:
137
  st.error(f"Error reading uploaded files: {str(e)}")
138
 
139
+ # Process the record if loaded
140
  if record_loaded and record is not None and annotation is not None:
141
  beats = extract_beats(record, annotation)
142
  if len(beats) == 0:
143
  st.error("No valid beats found in the record")
144
  st.stop()
145
+
146
  beats = beats.reshape((-1, 257, 1)).astype(np.float32)
147
  predictions = model.predict(beats)
148
  predicted_classes = np.argmax(predictions, axis=1)
 
157
 
158
  # Class Distribution Section
159
  st.subheader("Class Distribution")
 
 
160
  class_indices = list(class_map.keys())
161
  class_names = [class_map[i] for i in class_indices]
162
  counts = [np.sum(predicted_classes == i) for i in class_indices]
 
 
163
  distribution_df = pd.DataFrame({
164
  "Class": class_names,
165
  "Count": counts
166
  })
 
 
167
  col1, col2 = st.columns([1, 2])
168
  with col1:
169
  st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
 
170
  with col2:
171
  st.bar_chart(distribution_df.set_index('Class'))
172
 
173
+ # Display a Sample ECG Beat
174
  st.subheader("Sample ECG Beat")
175
  fig, ax = plt.subplots()
176
+ ax.plot(beats[0].flatten(), label="ECG Beat")
177
+ ax.legend()
178
+ st.pyplot(fig)
179
+
180
+ # ---------------- Grad-CAM Visualization Section ----------------
181
+ st.subheader("Grad-CAM Heatmap Visualization for Each Beat")
182
+ st.write("Below are Grad-CAM heatmaps overlaying each beat. The heatmaps show the regions contributing most to the predicted class.")
183
+
184
+ # Automatically detect the last convolutional layer name
185
+ conv_layer_name = get_last_conv_layer_name(model)
186
+ if conv_layer_name is not None:
187
+ st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
188
+
189
+ # Optionally, you can limit the number of beats displayed to avoid long processing times.
190
+ # For demonstration, here we process all beats, but you might want to show only the first N beats.
191
+ show_all = st.checkbox("Show Grad-CAM for all beats", value=False)
192
+ if not show_all:
193
+ num_beats_to_show = st.number_input("Number of beats to show:", min_value=1, max_value=len(beats), value=5)
194
+ else:
195
+ num_beats_to_show = len(beats)
196
+
197
+ # Loop over each beat and its prediction to generate Grad-CAM heatmap
198
+ for idx in range(num_beats_to_show):
199
+ beat = beats[idx]
200
+ pred_class = predicted_classes[idx]
201
+ predicted_label = class_map[pred_class]
202
+ # Compute Grad-CAM heatmap for the beat
203
+ heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)
204
+
205
+ # Generate visualization figure
206
+ fig, ax = plt.subplots(figsize=(10, 3))
207
+ # Plot the raw ECG beat
208
+ ax.plot(beat.flatten(), color="black", label="ECG Beat")
209
+ # Overlay Grad-CAM heatmap by scatter plotting points with a colormap according to heatmap value
210
+ sc = ax.scatter(np.arange(len(beat)), beat.flatten(), c=heatmap, cmap="jet", s=25)
211
+ ax.set_title(f"Beat {idx} - Predicted: {predicted_label}")
212
+ ax.set_xlabel("Time Index")
213
+ ax.set_ylabel("Amplitude")
214
+ # Add a colorbar to indicate heatmap intensity
215
+ fig.colorbar(sc, ax=ax, label="Grad-CAM Intensity")
216
+ st.pyplot(fig)