eyupipler commited on
Commit
96fa53f
·
verified ·
1 Parent(s): 8944117

Upload epilepsy_detection_model.py

Browse files
Files changed (1) hide show
  1. v1/epilepsy_detection_model.py +406 -0
v1/epilepsy_detection_model.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import ast
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, f1_score
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ from tqdm import tqdm
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ print(f"Using device: {device}")
17
+
18
+ class EpilepsyDataset(Dataset):
19
+ def __init__(self, csv_path):
20
+ self.data = pd.read_csv(csv_path)
21
+
22
+ def __len__(self):
23
+ return len(self.data)
24
+
25
+ def __getitem__(self, idx):
26
+ # Parse the data string to list
27
+ data_list = ast.literal_eval(self.data.iloc[idx]['data'])
28
+ data_tensor = torch.FloatTensor(data_list)
29
+ label = torch.LongTensor([self.data.iloc[idx]['label']])[0]
30
+ return data_tensor, label
31
+
32
+ class MultiHeadAttention(nn.Module):
33
+ def __init__(self, d_model, num_heads, dropout=0.1):
34
+ super().__init__()
35
+ assert d_model % num_heads == 0
36
+
37
+ self.d_model = d_model
38
+ self.num_heads = num_heads
39
+ self.d_k = d_model // num_heads
40
+
41
+ self.W_q = nn.Linear(d_model, d_model)
42
+ self.W_k = nn.Linear(d_model, d_model)
43
+ self.W_v = nn.Linear(d_model, d_model)
44
+ self.W_o = nn.Linear(d_model, d_model)
45
+
46
+ self.dropout = nn.Dropout(dropout)
47
+
48
+ def forward(self, x):
49
+ batch_size = x.size(0)
50
+
51
+ Q = self.W_q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
52
+ K = self.W_k(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
53
+ V = self.W_v(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
54
+
55
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
56
+ attn_weights = F.softmax(scores, dim=-1)
57
+ attn_weights = self.dropout(attn_weights)
58
+
59
+ context = torch.matmul(attn_weights, V)
60
+ context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
61
+
62
+ output = self.W_o(context)
63
+ return output, attn_weights
64
+
65
+ class TemporalConvBlock(nn.Module):
66
+ def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
67
+ super().__init__()
68
+ padding = (kernel_size - 1) * dilation // 2
69
+
70
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
71
+ padding=padding, dilation=dilation)
72
+ self.bn1 = nn.BatchNorm1d(out_channels)
73
+ self.relu1 = nn.ReLU()
74
+ self.dropout1 = nn.Dropout(dropout)
75
+
76
+ self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
77
+ padding=padding, dilation=dilation)
78
+ self.bn2 = nn.BatchNorm1d(out_channels)
79
+ self.relu2 = nn.ReLU()
80
+ self.dropout2 = nn.Dropout(dropout)
81
+
82
+ self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
83
+
84
+ def forward(self, x):
85
+ residual = x if self.downsample is None else self.downsample(x)
86
+
87
+ out = self.conv1(x)
88
+ out = self.bn1(out)
89
+ out = self.relu1(out)
90
+ out = self.dropout1(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+
95
+ out = out + residual
96
+ out = self.relu2(out)
97
+ out = self.dropout2(out)
98
+
99
+ return out
100
+
101
+ class ChannelAttention(nn.Module):
102
+ def __init__(self, channels, reduction=8):
103
+ super().__init__()
104
+ self.avg_pool = nn.AdaptiveAvgPool1d(1)
105
+ self.max_pool = nn.AdaptiveMaxPool1d(1)
106
+
107
+ self.fc = nn.Sequential(
108
+ nn.Linear(channels, channels // reduction, bias=False),
109
+ nn.ReLU(),
110
+ nn.Linear(channels // reduction, channels, bias=False),
111
+ nn.Sigmoid()
112
+ )
113
+
114
+ def forward(self, x):
115
+ b, c, _ = x.size()
116
+
117
+ avg_out = self.fc(self.avg_pool(x).view(b, c))
118
+ max_out = self.fc(self.max_pool(x).view(b, c))
119
+
120
+ out = avg_out + max_out
121
+ return x * out.view(b, c, 1)
122
+
123
+ class AdvancedEpilepsyDetector(nn.Module):
124
+ def __init__(self, input_dim=178, num_classes=2, dropout=0.3):
125
+ super().__init__()
126
+
127
+ self.input_proj = nn.Linear(input_dim, 256)
128
+ self.input_bn = nn.BatchNorm1d(256)
129
+
130
+ self.tcn_blocks = nn.ModuleList([
131
+ TemporalConvBlock(1, 64, kernel_size=7, dilation=1, dropout=dropout),
132
+ TemporalConvBlock(64, 128, kernel_size=5, dilation=2, dropout=dropout),
133
+ TemporalConvBlock(128, 256, kernel_size=3, dilation=4, dropout=dropout),
134
+ TemporalConvBlock(256, 256, kernel_size=3, dilation=8, dropout=dropout),
135
+ ])
136
+
137
+ self.channel_attn = ChannelAttention(256)
138
+
139
+ self.mha1 = MultiHeadAttention(256, num_heads=8, dropout=dropout)
140
+ self.mha2 = MultiHeadAttention(256, num_heads=8, dropout=dropout)
141
+
142
+ self.layer_norm1 = nn.LayerNorm(256)
143
+ self.layer_norm2 = nn.LayerNorm(256)
144
+
145
+ self.ffn = nn.Sequential(
146
+ nn.Linear(256, 512),
147
+ nn.ReLU(),
148
+ nn.Dropout(dropout),
149
+ nn.Linear(512, 256),
150
+ nn.Dropout(dropout)
151
+ )
152
+
153
+ self.bilstm = nn.LSTM(256, 128, num_layers=2, batch_first=True,
154
+ bidirectional=True, dropout=dropout)
155
+
156
+ self.classifier = nn.Sequential(
157
+ nn.Linear(256 + 256, 512), # TCN output + LSTM output
158
+ nn.BatchNorm1d(512),
159
+ nn.ReLU(),
160
+ nn.Dropout(dropout),
161
+
162
+ nn.Linear(512, 256),
163
+ nn.BatchNorm1d(256),
164
+ nn.ReLU(),
165
+ nn.Dropout(dropout),
166
+
167
+ nn.Linear(256, 128),
168
+ nn.BatchNorm1d(128),
169
+ nn.ReLU(),
170
+ nn.Dropout(dropout),
171
+
172
+ nn.Linear(128, num_classes)
173
+ )
174
+
175
+ def forward(self, x):
176
+ batch_size = x.size(0)
177
+
178
+ x_proj = self.input_proj(x)
179
+ x_proj = self.input_bn(x_proj)
180
+ x_proj = F.relu(x_proj)
181
+
182
+ x_tcn = x_proj.unsqueeze(1)
183
+ for tcn_block in self.tcn_blocks:
184
+ x_tcn = tcn_block(x_tcn)
185
+
186
+ x_tcn = self.channel_attn(x_tcn)
187
+ x_tcn = x_tcn.squeeze(1) if x_tcn.dim() == 3 and x_tcn.size(1) == 1 else x_tcn.mean(dim=-1)
188
+
189
+ x_trans = x_proj.unsqueeze(1)
190
+
191
+ attn_out1, _ = self.mha1(x_trans)
192
+ x_trans = self.layer_norm1(x_trans + attn_out1)
193
+
194
+ attn_out2, _ = self.mha2(x_trans)
195
+ x_trans = self.layer_norm2(x_trans + attn_out2)
196
+
197
+ ffn_out = self.ffn(x_trans)
198
+ x_trans = x_trans + ffn_out
199
+
200
+ lstm_out, _ = self.bilstm(x_trans)
201
+ lstm_out = lstm_out[:, -1, :]
202
+
203
+ combined = torch.cat([x_tcn, lstm_out], dim=1)
204
+
205
+ output = self.classifier(combined)
206
+
207
+ return output
208
+
209
+ class FocalLoss(nn.Module):
210
+ def __init__(self, alpha=0.25, gamma=2):
211
+ super().__init__()
212
+ self.alpha = alpha
213
+ self.gamma = gamma
214
+
215
+ def forward(self, inputs, targets):
216
+ ce_loss = F.cross_entropy(inputs, targets, reduction='none')
217
+ pt = torch.exp(-ce_loss)
218
+ focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
219
+ return focal_loss.mean()
220
+
221
+ def train_epoch(model, dataloader, criterion, optimizer, device):
222
+ model.train()
223
+ running_loss = 0.0
224
+ all_preds = []
225
+ all_labels = []
226
+
227
+ pbar = tqdm(dataloader, desc='Training')
228
+ for data, labels in pbar:
229
+ data, labels = data.to(device), labels.to(device)
230
+
231
+ optimizer.zero_grad()
232
+ outputs = model(data)
233
+ loss = criterion(outputs, labels)
234
+ loss.backward()
235
+
236
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
237
+
238
+ optimizer.step()
239
+
240
+ running_loss += loss.item()
241
+ _, preds = torch.max(outputs, 1)
242
+ all_preds.extend(preds.cpu().numpy())
243
+ all_labels.extend(labels.cpu().numpy())
244
+
245
+ pbar.set_postfix({'loss': loss.item()})
246
+
247
+ epoch_loss = running_loss / len(dataloader)
248
+ epoch_f1 = f1_score(all_labels, all_preds, average='weighted')
249
+
250
+ return epoch_loss, epoch_f1
251
+
252
+ def validate(model, dataloader, criterion, device):
253
+ model.eval()
254
+ running_loss = 0.0
255
+ all_preds = []
256
+ all_labels = []
257
+ all_probs = []
258
+
259
+ with torch.no_grad():
260
+ for data, labels in tqdm(dataloader, desc='Validation'):
261
+ data, labels = data.to(device), labels.to(device)
262
+
263
+ outputs = model(data)
264
+ loss = criterion(outputs, labels)
265
+
266
+ running_loss += loss.item()
267
+ probs = F.softmax(outputs, dim=1)
268
+ _, preds = torch.max(outputs, 1)
269
+
270
+ all_preds.extend(preds.cpu().numpy())
271
+ all_labels.extend(labels.cpu().numpy())
272
+ all_probs.extend(probs.cpu().numpy()[:, 1])
273
+
274
+ epoch_loss = running_loss / len(dataloader)
275
+ epoch_f1 = f1_score(all_labels, all_preds, average='weighted')
276
+ epoch_auc = roc_auc_score(all_labels, all_probs)
277
+
278
+ return epoch_loss, epoch_f1, epoch_auc, all_preds, all_labels
279
+
280
+ def main():
281
+ BATCH_SIZE = 64
282
+ LEARNING_RATE = 0.001
283
+ NUM_EPOCHS = 100
284
+ PATIENCE = 15
285
+
286
+ print("Loading datasets...")
287
+ train_dataset = EpilepsyDataset(r'train/path')
288
+ val_dataset = EpilepsyDataset(r'val/path')
289
+ test_dataset = EpilepsyDataset(r'test/path')
290
+
291
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
292
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
293
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
294
+
295
+ print(f"Train samples: {len(train_dataset)}")
296
+ print(f"Val samples: {len(val_dataset)}")
297
+ print(f"Test samples: {len(test_dataset)}")
298
+
299
+ model = AdvancedEpilepsyDetector(input_dim=178, num_classes=2, dropout=0.3).to(device)
300
+
301
+ total_params = sum(p.numel() for p in model.parameters())
302
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
303
+ print(f"\nTotal parameters: {total_params:,}")
304
+ print(f"Trainable parameters: {trainable_params:,}")
305
+
306
+ criterion = FocalLoss(alpha=0.25, gamma=2)
307
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
308
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
309
+ patience=5, verbose=True)
310
+
311
+ best_val_f1 = 0.0
312
+ patience_counter = 0
313
+ train_losses, val_losses = [], []
314
+ train_f1s, val_f1s = [], []
315
+
316
+ print("\nStarting training...\n")
317
+
318
+ for epoch in range(NUM_EPOCHS):
319
+ print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
320
+ print("-" * 50)
321
+
322
+ train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
323
+
324
+ val_loss, val_f1, val_auc, _, _ = validate(model, val_loader, criterion, device)
325
+
326
+ scheduler.step(val_loss)
327
+
328
+ train_losses.append(train_loss)
329
+ val_losses.append(val_loss)
330
+ train_f1s.append(train_f1)
331
+ val_f1s.append(val_f1)
332
+
333
+ print(f"Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f}")
334
+ print(f"Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}\n")
335
+
336
+ if val_f1 > best_val_f1:
337
+ best_val_f1 = val_f1
338
+ patience_counter = 0
339
+ torch.save({
340
+ 'epoch': epoch,
341
+ 'model_state_dict': model.state_dict(),
342
+ 'optimizer_state_dict': optimizer.state_dict(),
343
+ 'val_f1': val_f1,
344
+ 'val_auc': val_auc
345
+ }, r'best_epilepsy_model.pth')
346
+ print(f"[SAVED] Model saved with Val F1: {val_f1:.4f}\n")
347
+ else:
348
+ patience_counter += 1
349
+ if patience_counter >= PATIENCE:
350
+ print(f"\nEarly stopping triggered after {epoch+1} epochs")
351
+ break
352
+
353
+ print("\nLoading best model for testing...")
354
+ checkpoint = torch.load(r'best_epilepsy_model.pth')
355
+ model.load_state_dict(checkpoint['model_state_dict'])
356
+
357
+ print("\nEvaluating on test set...")
358
+ test_loss, test_f1, test_auc, test_preds, test_labels = validate(model, test_loader, criterion, device)
359
+
360
+ print(f"\nTest Results:")
361
+ print(f"Test Loss: {test_loss:.4f}")
362
+ print(f"Test F1: {test_f1:.4f}")
363
+ print(f"Test AUC: {test_auc:.4f}")
364
+
365
+ print("\nClassification Report:")
366
+ print(classification_report(test_labels, test_preds, target_names=['Non-Seizure', 'Seizure']))
367
+
368
+ # Confusion matrix
369
+ cm = confusion_matrix(test_labels, test_preds)
370
+ plt.figure(figsize=(10, 8))
371
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
372
+ xticklabels=['Non-Seizure', 'Seizure'],
373
+ yticklabels=['Non-Seizure', 'Seizure'])
374
+ plt.title('Confusion Matrix - Epilepsy Detection')
375
+ plt.ylabel('True Label')
376
+ plt.xlabel('Predicted Label')
377
+ plt.savefig(r'confusion_matrix.png', dpi=300, bbox_inches='tight')
378
+ plt.close()
379
+
380
+ plt.figure(figsize=(15, 5))
381
+
382
+ plt.subplot(1, 2, 1)
383
+ plt.plot(train_losses, label='Train Loss')
384
+ plt.plot(val_losses, label='Val Loss')
385
+ plt.xlabel('Epoch')
386
+ plt.ylabel('Loss')
387
+ plt.title('Training and Validation Loss')
388
+ plt.legend()
389
+ plt.grid(True)
390
+
391
+ plt.subplot(1, 2, 2)
392
+ plt.plot(train_f1s, label='Train F1')
393
+ plt.plot(val_f1s, label='Val F1')
394
+ plt.xlabel('Epoch')
395
+ plt.ylabel('F1 Score')
396
+ plt.title('Training and Validation F1 Score')
397
+ plt.legend()
398
+ plt.grid(True)
399
+
400
+ plt.savefig(r'training_curves.png', dpi=300, bbox_inches='tight')
401
+ plt.close()
402
+
403
+ print("\nTraining completed! Results saved.")
404
+
405
+ if __name__ == "__main__":
406
+ main()