yw-Hua committed on
Commit
6a95667
·
1 Parent(s): 1a61527

Update codes

.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
codes/.DS_Store ADDED
Binary file (6.15 kB).
 
codes/Fine-tuning/.DS_Store ADDED
Binary file (6.15 kB).
 
codes/Fine-tuning/cell_type_classification/NuSPIRe_from_scratch.py ADDED
@@ -0,0 +1,292 @@
+ import torch
+ import random
+ import numpy as np
+ import os
+ from torch.utils.tensorboard import SummaryWriter
+ import pandas as pd
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, SubsetRandomSampler
+ from tqdm import tqdm
+ from transformers import ViTMAEConfig, ViTForImageClassification
+ from torchvision.datasets import ImageFolder
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, roc_auc_score, recall_score, confusion_matrix
+ from sklearn.preprocessing import label_binarize
+ from torch.optim.lr_scheduler import LambdaLR
+ import argparse
+
+
+ def set_seeds(seed_value=42, cuda_deterministic=False):
+     """Set seeds for reproducibility."""
+     random.seed(seed_value)
+     os.environ['PYTHONHASHSEED'] = str(seed_value)
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+         # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
+         if cuda_deterministic:  # slower, more reproducible
+             torch.backends.cudnn.deterministic = True
+             torch.backends.cudnn.benchmark = False
+         else:  # faster, less reproducible
+             torch.backends.cudnn.deterministic = False
+             torch.backends.cudnn.benchmark = True
+
+
+ def warmup_lr_lambda(current_epoch: int, warmup_epochs: int):
+     # Linear warmup: ramp the LR multiplier from 1/warmup_epochs up to 1.0
+     if current_epoch < warmup_epochs:
+         return float(current_epoch + 1) / float(max(1, warmup_epochs))
+     return 1.0
+
+
+ # Set-up
+ parser = argparse.ArgumentParser(description="Set up experiment parameters")
+ parser.add_argument('--num', type=int, required=True, help='Number of samples per class')
+ parser.add_argument('--device', type=int, default=0, help='CUDA device number (default: 0)')
+ parser.add_argument('--rep', type=int, required=True, help='Replicate number')
+ args = parser.parse_args()
+ num_samples_per_class = args.num
+ device = args.device
+ num_repeats = args.rep
+
+ SEED = 42
+ DEVICE = torch.device(f"cuda:{device}")
+ DATA_DIR = '../lung5_rep1_cancer_nuclear_image_15micron/'
+ BATCH_SIZE = 300
+ NUM_EPOCHS = 30
+ PROJECT_NAME = f'Nuspire_{num_samples_per_class}_r{num_repeats}_lung5_rep1_Classification'
+ set_seeds(SEED)
+ folder_name = f'./{PROJECT_NAME}_checkpoint'
+
+ if not os.path.exists(folder_name):
+     os.makedirs(folder_name)
+     print(f"'{folder_name}' has been created.")
+ else:
+     print(f"'{folder_name}' already exists.")
+
+ # Dataset
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     transforms.RandomHorizontalFlip(p=0.5),
+     transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ labels = [dataset[i][1] for i in range(len(dataset))]
+
+ # Define train and test sizes
+ train_size = int(0.8 * len(dataset))
+ valid_size = int(0.1 * len(dataset))
+ test_size = len(dataset) - train_size - valid_size
+
+ indices = np.arange(len(dataset))
+ np.random.shuffle(indices)
+
+ # Split
+ train_indices = indices[:train_size]
+ valid_indices = indices[train_size:train_size + valid_size]
+ test_indices = indices[train_size + valid_size:]
+ class_1_train_indices = [i for i in train_indices if labels[i] == 1]
+ class_2_train_indices = [i for i in train_indices if labels[i] == 2]
+ class_0_train_indices = [i for i in train_indices if labels[i] == 0]
+
+ # Reshuffle once per replicate so each replicate draws a different class-balanced subset
+ for repeat in range(num_repeats):
+     np.random.shuffle(class_1_train_indices)
+     np.random.shuffle(class_2_train_indices)
+     np.random.shuffle(class_0_train_indices)
+
+ class_1_train_indices = class_1_train_indices[:num_samples_per_class]
+ class_2_train_indices = class_2_train_indices[:num_samples_per_class]
+ class_0_train_indices = class_0_train_indices[:num_samples_per_class]
+
+ balanced_train_indices = (
+     class_1_train_indices +
+     class_2_train_indices +
+     class_0_train_indices
+ )
+ np.random.shuffle(balanced_train_indices)
+
+ train_sampler = SubsetRandomSampler(balanced_train_indices)
+ valid_sampler = SubsetRandomSampler(valid_indices)
+ test_sampler = SubsetRandomSampler(test_indices)
+
+ # print(balanced_train_indices)
+ # print(valid_indices)
+ # print(test_indices)
+
+ train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4)
+ valid_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler, num_workers=4)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)
+
+
+ # Model: architecture from the pretraining config, weights randomly initialized
+ config_path = "/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69/config.json"
+ config = ViTMAEConfig.from_json_file(config_path)
+ config.architectures = ["ViTForImageClassification"]
+ config.num_labels = 3
+ config.image_size = 112
+ config.num_channels = 1
+ model = ViTForImageClassification(config)
+ model.to(DEVICE)
+
+ # Training
+ optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
+ writer = SummaryWriter(f'./tensorboard/{PROJECT_NAME}')
+ step1 = 0
+ step2 = 0
+ best_val_loss = float('inf')
+ best_val_f1 = 0
+ warmup_epochs = 5
+ scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_lr_lambda(epoch, warmup_epochs))
+
+ for epoch in range(NUM_EPOCHS):
+     print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")
+     model.train()
+     train_preds, train_labels = [], []
+     loss_list = []
+     for i, (x, l) in tqdm(enumerate(train_loader), total=len(train_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         # Debug: log batch tensor shapes
+         print(f"Input shape: {x.shape}")
+         print(f"Label shape: {l.shape}")
+
+         optimizer.zero_grad()
+
+         outputs = model(x, labels=l)
+
+         loss = outputs.loss
+
+         _, predicted = torch.max(outputs.logits, 1)
+         train_preds.extend(predicted.cpu().numpy())
+         train_labels.extend(l.cpu().numpy())
+
+         writer.add_scalar("Step/Train Loss", loss.item(), step1)
+         loss_list.append(loss.item())
+
+         step1 += 1
+         loss.backward()
+         optimizer.step()
+
+     train_loss = np.mean(loss_list)
+     train_accuracy = 100 * (np.array(train_preds) == np.array(train_labels)).mean()
+     train_f1 = f1_score(train_labels, train_preds, average='macro')
+     train_precision = precision_score(train_labels, train_preds, average='macro')
+
+     model.eval()
+     val_preds, val_labels = [], []
+     loss_list = []
+     with torch.no_grad():
+         for i, (x, l) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
+             x = x.to(DEVICE)
+             l = l.to(DEVICE)
+
+             outputs = model(x, labels=l)
+
+             loss = outputs.loss
+
+             _, predicted = torch.max(outputs.logits, 1)
+             val_preds.extend(predicted.cpu().numpy())
+             val_labels.extend(l.cpu().numpy())
+
+             writer.add_scalar("Step/Validation Loss", loss.item(), step2)
+
+             loss_list.append(loss.item())
+             step2 += 1
+     val_loss = np.mean(loss_list)
+     val_accuracy = 100 * (np.array(val_preds) == np.array(val_labels)).mean()
+     val_f1 = f1_score(val_labels, val_preds, average='macro')
+     val_precision = precision_score(val_labels, val_preds, average='macro')
+
+     val_labels_bin = label_binarize(val_labels, classes=[0, 1, 2])
+     val_preds_bin = label_binarize(val_preds, classes=[0, 1, 2])
+     # Note: AUC here is computed from one-hot hard predictions, not softmax probabilities
+     val_auc = roc_auc_score(val_labels_bin, val_preds_bin, average='macro', multi_class='ovr')
+
+     # Save the model if the validation loss is the best we've seen so far.
+     if val_loss < best_val_loss:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_loss_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_loss_model')
+         best_val_loss = val_loss
+
+     # Save the model if the validation F1 score is the best we've seen so far.
+     if val_f1 > best_val_f1:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_f1_model')
+         best_val_f1 = val_f1
+
+     lr = optimizer.param_groups[0]['lr']
+     writer.add_scalar("Epoch/Lr", lr, epoch)
+     writer.add_scalar("Epoch/Validation ROC AUC", val_auc, epoch)
+     writer.add_scalars("Epoch/Loss", {'Train Loss': train_loss, 'Validation Loss': val_loss}, epoch)
+     writer.add_scalars("Epoch/ACC", {'Train ACC': train_accuracy, 'Validation ACC': val_accuracy}, epoch)
+     writer.add_scalars("Epoch/Precision", {'Train Precision': train_precision, 'Validation Precision': val_precision}, epoch)
+     writer.add_scalars("Epoch/F1_Score", {'Train F1 Score': train_f1, 'Validation F1 Score': val_f1}, epoch)
+
+     print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Train ACC: {train_accuracy:.4f}%, Train F1: {train_f1:.4f}, Train Precision: {train_precision:.4f}")
+     print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}, Validation ACC: {val_accuracy:.4f}%, Validation F1: {val_f1:.4f}, Validation Precision: {val_precision:.4f}, Validation ROC AUC: {val_auc:.4f}")
+
+     scheduler.step()
+
+ # Test with the best-F1 model (augmentations disabled)
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     # transforms.RandomHorizontalFlip(p=0.5),
+     # transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
+ model_path = f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt'
+ model.load_state_dict(torch.load(model_path))
+ model.to(DEVICE)
+ model.eval()
+ test_preds, test_labels = [], []
+ test_probs = []
+
+ with torch.no_grad():
+     for x, l in tqdm(test_loader, total=len(test_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         outputs = model(x)
+         probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
+         _, predicted = torch.max(probabilities, 1)
+
+         test_preds.extend(predicted.cpu().numpy())
+         test_labels.extend(l.cpu().numpy())
+         test_probs.extend(probabilities.cpu().numpy())
+
+ test_probs = np.array(test_probs)
+
+ df = pd.DataFrame({
+     'True Labels': test_labels,
+     'Predicted Labels': test_preds
+ })
+
+ for i in range(test_probs.shape[1]):
+     df[f'Prob_Class{i}'] = test_probs[:, i]
+
+ df.to_csv(f'{PROJECT_NAME}.csv', index=False)
+ print("Test labels, predictions, and probabilities have been saved.")
+
+ test_labels_binarized = label_binarize(test_labels, classes=[0, 1, 2])
+ test_preds_binarized = label_binarize(test_preds, classes=[0, 1, 2])
+
+ accuracy = accuracy_score(test_labels, test_preds)
+ f1 = f1_score(test_labels, test_preds, average='macro')
+ precision = precision_score(test_labels, test_preds, average='macro')
+ recall = recall_score(test_labels, test_preds, average='macro')
+ rocauc = roc_auc_score(test_labels_binarized, test_preds_binarized, average='macro')
+
+ print(f'Accuracy: {accuracy:.4f}')
+ print(f'F1 Score: {f1:.4f}')
+ print(f'Precision: {precision:.4f}')
+ print(f'Recall: {recall:.4f}')
+ print(f'ROC AUC: {rocauc:.4f}')
+
+ conf_matrix = confusion_matrix(test_labels, test_preds)
+ print("Confusion Matrix:")
+ print(conf_matrix)
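For orientation, a minimal sketch (not part of the commit) of what the five-epoch linear warmup above does to the learning-rate multiplier; warmup_lr_lambda is copied verbatim from the script, and the epoch range printed is illustrative:

def warmup_lr_lambda(current_epoch: int, warmup_epochs: int):
    # Multiplier ramps linearly from 1/warmup_epochs up to 1.0, then holds
    if current_epoch < warmup_epochs:
        return float(current_epoch + 1) / float(max(1, warmup_epochs))
    return 1.0

print([warmup_lr_lambda(e, 5) for e in range(8)])
# [0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0]

With the base LR of 0.0001 set above, the effective LR therefore climbs from 2e-5 to 1e-4 over the first five epochs and stays at 1e-4 afterwards.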
codes/Fine-tuning/cell_type_classification/NuSPIRe_full_fine-tuning.py ADDED
@@ -0,0 +1,284 @@
+ import torch
+ import random
+ import numpy as np
+ import os
+ from torch.utils.tensorboard import SummaryWriter
+ import pandas as pd
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, SubsetRandomSampler
+ from tqdm import tqdm
+ from transformers import ViTForImageClassification
+ from torchvision.datasets import ImageFolder
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, roc_auc_score, recall_score, confusion_matrix
+ from sklearn.preprocessing import label_binarize
+ from torch.optim.lr_scheduler import LambdaLR
+ import argparse
+
+
+ def set_seeds(seed_value=42, cuda_deterministic=False):
+     """Set seeds for reproducibility."""
+     random.seed(seed_value)
+     os.environ['PYTHONHASHSEED'] = str(seed_value)
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+         # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
+         if cuda_deterministic:  # slower, more reproducible
+             torch.backends.cudnn.deterministic = True
+             torch.backends.cudnn.benchmark = False
+         else:  # faster, less reproducible
+             torch.backends.cudnn.deterministic = False
+             torch.backends.cudnn.benchmark = True
+
+
+ def warmup_lr_lambda(current_epoch: int, warmup_epochs: int):
+     # Linear warmup: ramp the LR multiplier from 1/warmup_epochs up to 1.0
+     if current_epoch < warmup_epochs:
+         return float(current_epoch + 1) / float(max(1, warmup_epochs))
+     return 1.0
+
+
+ # Set-up
+ parser = argparse.ArgumentParser(description="Set up experiment parameters")
+ parser.add_argument('--num', type=int, required=True, help='Number of samples per class')
+ parser.add_argument('--device', type=int, default=0, help='CUDA device number (default: 0)')
+ parser.add_argument('--rep', type=int, required=True, help='Replicate number')
+ args = parser.parse_args()
+ num_samples_per_class = args.num
+ device = args.device
+ num_repeats = args.rep
+
+ SEED = 42
+ DEVICE = torch.device(f"cuda:{device}")
+ DATA_DIR = '../lung5_rep1_cancer_nuclear_image_15micron/'
+ BATCH_SIZE = 300
+ NUM_EPOCHS = 30
+ PROJECT_NAME = f'Nuspire_{num_samples_per_class}_lung5_rep1_Classification'
+ set_seeds(SEED)
+ folder_name = f'./{PROJECT_NAME}_checkpoint'
+
+ if not os.path.exists(folder_name):
+     os.makedirs(folder_name)
+     print(f"'{folder_name}' has been created.")
+ else:
+     print(f"'{folder_name}' already exists.")
+
+ # Dataset
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     transforms.RandomHorizontalFlip(p=0.5),
+     transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ labels = [dataset[i][1] for i in range(len(dataset))]
+
+ # Define train and test sizes
+ train_size = int(0.8 * len(dataset))
+ valid_size = int(0.1 * len(dataset))
+ test_size = len(dataset) - train_size - valid_size
+
+ indices = np.arange(len(dataset))
+ np.random.shuffle(indices)
+
+ # Split
+ train_indices = indices[:train_size]
+ valid_indices = indices[train_size:train_size + valid_size]
+ test_indices = indices[train_size + valid_size:]
+ class_1_train_indices = [i for i in train_indices if labels[i] == 1]
+ class_2_train_indices = [i for i in train_indices if labels[i] == 2]
+ class_0_train_indices = [i for i in train_indices if labels[i] == 0]
+
+ # Reshuffle once per replicate so each replicate draws a different class-balanced subset
+ for repeat in range(num_repeats):
+     np.random.shuffle(class_1_train_indices)
+     np.random.shuffle(class_2_train_indices)
+     np.random.shuffle(class_0_train_indices)
+
+ class_1_train_indices = class_1_train_indices[:num_samples_per_class]
+ class_2_train_indices = class_2_train_indices[:num_samples_per_class]
+ class_0_train_indices = class_0_train_indices[:num_samples_per_class]
+
+ balanced_train_indices = (
+     class_1_train_indices +
+     class_2_train_indices +
+     class_0_train_indices
+ )
+ np.random.shuffle(balanced_train_indices)
+
+ train_sampler = SubsetRandomSampler(balanced_train_indices)
+ valid_sampler = SubsetRandomSampler(valid_indices)
+ test_sampler = SubsetRandomSampler(test_indices)
+
+ # print(balanced_train_indices)
+ # print(valid_indices)
+ # print(test_indices)
+
+ train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4)
+ valid_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler, num_workers=4)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)
+
+
+ # Model: pretrained encoder weights, new 3-way classification head
+ model = ViTForImageClassification.from_pretrained("/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69", num_labels=3)
+ model.to(DEVICE)
+
+ # Training
+ optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
+ writer = SummaryWriter(f'./tensorboard/{PROJECT_NAME}')
+ step1 = 0
+ step2 = 0
+ best_val_loss = float('inf')
+ best_val_f1 = 0
+ warmup_epochs = 5
+ scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_lr_lambda(epoch, warmup_epochs))
+
+ for epoch in range(NUM_EPOCHS):
+     print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")
+     model.train()
+     train_preds, train_labels = [], []
+     loss_list = []
+     for i, (x, l) in tqdm(enumerate(train_loader), total=len(train_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         optimizer.zero_grad()
+
+         outputs = model(x, labels=l)
+
+         loss = outputs.loss
+
+         _, predicted = torch.max(outputs.logits, 1)
+         train_preds.extend(predicted.cpu().numpy())
+         train_labels.extend(l.cpu().numpy())
+
+         writer.add_scalar("Step/Train Loss", loss.item(), step1)
+         loss_list.append(loss.item())
+
+         step1 += 1
+         loss.backward()
+         optimizer.step()
+
+     train_loss = np.mean(loss_list)
+     train_accuracy = 100 * (np.array(train_preds) == np.array(train_labels)).mean()
+     train_f1 = f1_score(train_labels, train_preds, average='macro')
+     train_precision = precision_score(train_labels, train_preds, average='macro')
+
+     model.eval()
+     val_preds, val_labels = [], []
+     loss_list = []
+     with torch.no_grad():
+         for i, (x, l) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
+             x = x.to(DEVICE)
+             l = l.to(DEVICE)
+
+             outputs = model(x, labels=l)
+
+             loss = outputs.loss
+
+             _, predicted = torch.max(outputs.logits, 1)
+             val_preds.extend(predicted.cpu().numpy())
+             val_labels.extend(l.cpu().numpy())
+
+             writer.add_scalar("Step/Validation Loss", loss.item(), step2)
+
+             loss_list.append(loss.item())
+             step2 += 1
+     val_loss = np.mean(loss_list)
+     val_accuracy = 100 * (np.array(val_preds) == np.array(val_labels)).mean()
+     val_f1 = f1_score(val_labels, val_preds, average='macro')
+     val_precision = precision_score(val_labels, val_preds, average='macro')
+
+     val_labels_bin = label_binarize(val_labels, classes=[0, 1, 2])
+     val_preds_bin = label_binarize(val_preds, classes=[0, 1, 2])
+     # Note: AUC here is computed from one-hot hard predictions, not softmax probabilities
+     val_auc = roc_auc_score(val_labels_bin, val_preds_bin, average='macro', multi_class='ovr')
+
+     # Save the model if the validation loss is the best we've seen so far.
+     if val_loss < best_val_loss:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_loss_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_loss_model')
+         best_val_loss = val_loss
+
+     # Save the model if the validation F1 score is the best we've seen so far.
+     if val_f1 > best_val_f1:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_f1_model')
+         best_val_f1 = val_f1
+
+     lr = optimizer.param_groups[0]['lr']
+     writer.add_scalar("Epoch/Lr", lr, epoch)
+     writer.add_scalar("Epoch/Validation ROC AUC", val_auc, epoch)
+     writer.add_scalars("Epoch/Loss", {'Train Loss': train_loss, 'Validation Loss': val_loss}, epoch)
+     writer.add_scalars("Epoch/ACC", {'Train ACC': train_accuracy, 'Validation ACC': val_accuracy}, epoch)
+     writer.add_scalars("Epoch/Precision", {'Train Precision': train_precision, 'Validation Precision': val_precision}, epoch)
+     writer.add_scalars("Epoch/F1_Score", {'Train F1 Score': train_f1, 'Validation F1 Score': val_f1}, epoch)
+
+     print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Train ACC: {train_accuracy:.4f}%, Train F1: {train_f1:.4f}, Train Precision: {train_precision:.4f}")
+     print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}, Validation ACC: {val_accuracy:.4f}%, Validation F1: {val_f1:.4f}, Validation Precision: {val_precision:.4f}, Validation ROC AUC: {val_auc:.4f}")
+
+     scheduler.step()
+
+ # Test with the best-F1 model (augmentations disabled)
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     # transforms.RandomHorizontalFlip(p=0.5),
+     # transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
+ model_path = f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt'
+ model.load_state_dict(torch.load(model_path))
+ model.to(DEVICE)
+ model.eval()
+ test_preds, test_labels = [], []
+ test_probs = []
+
+ with torch.no_grad():
+     for x, l in tqdm(test_loader, total=len(test_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         outputs = model(x)
+         probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
+         _, predicted = torch.max(probabilities, 1)
+
+         test_preds.extend(predicted.cpu().numpy())
+         test_labels.extend(l.cpu().numpy())
+         test_probs.extend(probabilities.cpu().numpy())
+
+ test_probs = np.array(test_probs)
+
+ df = pd.DataFrame({
+     'True Labels': test_labels,
+     'Predicted Labels': test_preds
+ })
+
+ for i in range(test_probs.shape[1]):
+     df[f'Prob_Class{i}'] = test_probs[:, i]
+
+ df.to_csv(f'{PROJECT_NAME}.csv', index=False)
+ print("Test labels, predictions, and probabilities have been saved.")
+
+ test_labels_binarized = label_binarize(test_labels, classes=[0, 1, 2])
+ test_preds_binarized = label_binarize(test_preds, classes=[0, 1, 2])
+
+ accuracy = accuracy_score(test_labels, test_preds)
+ f1 = f1_score(test_labels, test_preds, average='macro')
+ precision = precision_score(test_labels, test_preds, average='macro')
+ recall = recall_score(test_labels, test_preds, average='macro')
+ rocauc = roc_auc_score(test_labels_binarized, test_preds_binarized, average='macro')
+
+ print(f'Accuracy: {accuracy:.4f}')
+ print(f'F1 Score: {f1:.4f}')
+ print(f'Precision: {precision:.4f}')
+ print(f'Recall: {recall:.4f}')
+ print(f'ROC AUC: {rocauc:.4f}')
+
+ conf_matrix = confusion_matrix(test_labels, test_preds)
+ print("Confusion Matrix:")
+ print(conf_matrix)
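This script differs from NuSPIRe_from_scratch.py mainly in how the model is constructed. A minimal sketch of the two initialization paths, assuming CKPT points at a ViT-MAE pretraining checkpoint directory (the path below is a placeholder, not the authors' path, and the config details of the from-scratch script, such as image_size=112 and num_channels=1, are omitted):

from transformers import ViTMAEConfig, ViTForImageClassification

CKPT = "/path/to/pretrain_checkpoint"  # placeholder

# From scratch: only the architecture is read from the config;
# all weights (encoder and 3-way head) start randomly initialized.
config = ViTMAEConfig.from_json_file(f"{CKPT}/config.json")
config.num_labels = 3
scratch_model = ViTForImageClassification(config)

# Full fine-tuning: pretrained encoder weights are loaded; only the
# 3-way classifier head is newly initialized, and every parameter trains.
finetune_model = ViTForImageClassification.from_pretrained(CKPT, num_labels=3)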
codes/Fine-tuning/cell_type_classification/NuSPIRe_partial_fine-tuning.py ADDED
@@ -0,0 +1,310 @@
+ import torch
+ import torch.nn as nn
+ import random
+ import numpy as np
+ import os
+ from torch.utils.tensorboard import SummaryWriter
+ import pandas as pd
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, SubsetRandomSampler
+ from tqdm import tqdm
+ from transformers import ViTForImageClassification
+ from torchvision.datasets import ImageFolder
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, roc_auc_score, recall_score, confusion_matrix
+ from sklearn.preprocessing import label_binarize
+ from torch.optim.lr_scheduler import LambdaLR
+ import argparse
+
+
+ def set_seeds(seed_value=42, cuda_deterministic=False):
+     """Set seeds for reproducibility."""
+     random.seed(seed_value)
+     os.environ['PYTHONHASHSEED'] = str(seed_value)
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+         # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
+         if cuda_deterministic:  # slower, more reproducible
+             torch.backends.cudnn.deterministic = True
+             torch.backends.cudnn.benchmark = False
+         else:  # faster, less reproducible
+             torch.backends.cudnn.deterministic = False
+             torch.backends.cudnn.benchmark = True
+
+
+ def warmup_lr_lambda(current_epoch: int, warmup_epochs: int):
+     # Linear warmup: ramp the LR multiplier from 1/warmup_epochs up to 1.0
+     if current_epoch < warmup_epochs:
+         return float(current_epoch + 1) / float(max(1, warmup_epochs))
+     return 1.0
+
+
+ # Set-up
+ parser = argparse.ArgumentParser(description="Set up experiment parameters")
+ parser.add_argument('--num', type=int, required=True, help='Number of samples per class')
+ parser.add_argument('--device', type=int, default=0, help='CUDA device number (default: 0)')
+ parser.add_argument('--rep', type=int, required=True, help='Replicate number')
+ args = parser.parse_args()
+ num_samples_per_class = args.num
+ device = args.device
+ num_repeats = args.rep
+
+ SEED = 42
+ DEVICE = torch.device(f"cuda:{device}")
+ DATA_DIR = '../lung5_rep1_cancer_nuclear_image_15micron/'
+ BATCH_SIZE = 300
+ NUM_EPOCHS = 30
+ PROJECT_NAME = f'MLP_Frozen_{num_samples_per_class}_lung5_rep1_Classification'
+ set_seeds(SEED)
+ folder_name = f'./{PROJECT_NAME}_checkpoint'
+
+ if not os.path.exists(folder_name):
+     os.makedirs(folder_name)
+     print(f"'{folder_name}' has been created.")
+ else:
+     print(f"'{folder_name}' already exists.")
+
+ # Dataset
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     transforms.RandomHorizontalFlip(p=0.5),
+     transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ labels = [dataset[i][1] for i in range(len(dataset))]
+
+ # Define train and test sizes
+ train_size = int(0.8 * len(dataset))
+ valid_size = int(0.1 * len(dataset))
+ test_size = len(dataset) - train_size - valid_size
+
+ indices = np.arange(len(dataset))
+ np.random.shuffle(indices)
+
+ # Split
+ train_indices = indices[:train_size]
+ valid_indices = indices[train_size:train_size + valid_size]
+ test_indices = indices[train_size + valid_size:]
+ class_1_train_indices = [i for i in train_indices if labels[i] == 1]
+ class_2_train_indices = [i for i in train_indices if labels[i] == 2]
+ class_0_train_indices = [i for i in train_indices if labels[i] == 0]
+
+ # Reshuffle once per replicate so each replicate draws a different class-balanced subset
+ for repeat in range(num_repeats):
+     np.random.shuffle(class_1_train_indices)
+     np.random.shuffle(class_2_train_indices)
+     np.random.shuffle(class_0_train_indices)
+
+ class_1_train_indices = class_1_train_indices[:num_samples_per_class]
+ class_2_train_indices = class_2_train_indices[:num_samples_per_class]
+ class_0_train_indices = class_0_train_indices[:num_samples_per_class]
+
+ balanced_train_indices = (
+     class_1_train_indices +
+     class_2_train_indices +
+     class_0_train_indices
+ )
+ np.random.shuffle(balanced_train_indices)
+
+ train_sampler = SubsetRandomSampler(balanced_train_indices)
+ valid_sampler = SubsetRandomSampler(valid_indices)
+ test_sampler = SubsetRandomSampler(test_indices)
+
+ # print(balanced_train_indices)
+ # print(valid_indices)
+ # print(test_indices)
+
+ train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4)
+ valid_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler, num_workers=4)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)
+
+
+ # Model: load pretrained weights, then freeze everything except the classifier
+ model = ViTForImageClassification.from_pretrained("/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69", num_labels=3)
+ for name, param in model.named_parameters():
+     if 'classifier' not in name:
+         param.requires_grad = False
+
+
+ class MLP(nn.Module):
+     def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, output_dim):
+         super(MLP, self).__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim1)
+         self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
+         self.fc3 = nn.Linear(hidden_dim2, hidden_dim3)
+         self.fc4 = nn.Linear(hidden_dim3, hidden_dim4)
+         self.fc5 = nn.Linear(hidden_dim4, output_dim)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         x = self.relu(self.fc1(x))
+         x = self.relu(self.fc2(x))
+         x = self.relu(self.fc3(x))
+         x = self.relu(self.fc4(x))
+         x = self.fc5(x)
+         return x
+
+
+ # The replacement head is trainable by default; the frozen backbone stays frozen
+ model.classifier = MLP(input_dim=768, hidden_dim1=512, hidden_dim2=256, hidden_dim3=128, hidden_dim4=64, output_dim=3)
+ model.to(DEVICE)
+
+ # Training
+ optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
+ writer = SummaryWriter(f'./tensorboard/{PROJECT_NAME}')
+ step1 = 0
+ step2 = 0
+ best_val_loss = float('inf')
+ best_val_f1 = 0
+ warmup_epochs = 5
+ scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_lr_lambda(epoch, warmup_epochs))
+
+
+ for epoch in range(NUM_EPOCHS):
+     print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")
+     model.train()
+     train_preds, train_labels = [], []
+     loss_list = []
+     for i, (x, l) in tqdm(enumerate(train_loader), total=len(train_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         optimizer.zero_grad()
+
+         outputs = model(x, labels=l)
+
+         loss = outputs.loss
+
+         _, predicted = torch.max(outputs.logits, 1)
+         train_preds.extend(predicted.cpu().numpy())
+         train_labels.extend(l.cpu().numpy())
+
+         writer.add_scalar("Step/Train Loss", loss.item(), step1)
+         loss_list.append(loss.item())
+
+         step1 += 1
+         loss.backward()
+         optimizer.step()
+
+     train_loss = np.mean(loss_list)
+     train_accuracy = 100 * (np.array(train_preds) == np.array(train_labels)).mean()
+     train_f1 = f1_score(train_labels, train_preds, average='macro')
+     train_precision = precision_score(train_labels, train_preds, average='macro')
+
+     model.eval()
+     val_preds, val_labels = [], []
+     loss_list = []
+     with torch.no_grad():
+         for i, (x, l) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
+             x = x.to(DEVICE)
+             l = l.to(DEVICE)
+
+             outputs = model(x, labels=l)
+
+             loss = outputs.loss
+
+             _, predicted = torch.max(outputs.logits, 1)
+             val_preds.extend(predicted.cpu().numpy())
+             val_labels.extend(l.cpu().numpy())
+
+             writer.add_scalar("Step/Validation Loss", loss.item(), step2)
+
+             loss_list.append(loss.item())
+             step2 += 1
+     val_loss = np.mean(loss_list)
+     val_accuracy = 100 * (np.array(val_preds) == np.array(val_labels)).mean()
+     val_f1 = f1_score(val_labels, val_preds, average='macro')
+     val_precision = precision_score(val_labels, val_preds, average='macro')
+
+     val_labels_bin = label_binarize(val_labels, classes=[0, 1, 2])
+     val_preds_bin = label_binarize(val_preds, classes=[0, 1, 2])
+     # Note: AUC here is computed from one-hot hard predictions, not softmax probabilities
+     val_auc = roc_auc_score(val_labels_bin, val_preds_bin, average='macro', multi_class='ovr')
+
+     # Save the model if the validation loss is the best we've seen so far.
+     if val_loss < best_val_loss:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_loss_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_loss_model')
+         best_val_loss = val_loss
+
+     # Save the model if the validation F1 score is the best we've seen so far.
+     if val_f1 > best_val_f1:
+         torch.save(model.state_dict(), f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt')
+         model.save_pretrained(f'{folder_name}/{PROJECT_NAME}_best_f1_model')
+         best_val_f1 = val_f1
+
+     lr = optimizer.param_groups[0]['lr']
+     writer.add_scalar("Epoch/Lr", lr, epoch)
+     writer.add_scalar("Epoch/Validation ROC AUC", val_auc, epoch)
+     writer.add_scalars("Epoch/Loss", {'Train Loss': train_loss, 'Validation Loss': val_loss}, epoch)
+     writer.add_scalars("Epoch/ACC", {'Train ACC': train_accuracy, 'Validation ACC': val_accuracy}, epoch)
+     writer.add_scalars("Epoch/Precision", {'Train Precision': train_precision, 'Validation Precision': val_precision}, epoch)
+     writer.add_scalars("Epoch/F1_Score", {'Train F1 Score': train_f1, 'Validation F1 Score': val_f1}, epoch)
+
+     print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Train ACC: {train_accuracy:.4f}%, Train F1: {train_f1:.4f}, Train Precision: {train_precision:.4f}")
+     print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}, Validation ACC: {val_accuracy:.4f}%, Validation F1: {val_f1:.4f}, Validation Precision: {val_precision:.4f}, Validation ROC AUC: {val_auc:.4f}")
+
+     scheduler.step()
+
+ # Test with the best-F1 model (augmentations disabled)
+ transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.Grayscale(),
+     # transforms.RandomHorizontalFlip(p=0.5),
+     # transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ dataset = ImageFolder(DATA_DIR, transform=transform)
+ test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
+ model_path = f'{folder_name}/{PROJECT_NAME}_best_f1_model.pt'
+ model.load_state_dict(torch.load(model_path))
+ model.to(DEVICE)
+ model.eval()
+ test_preds, test_labels = [], []
+ test_probs = []
+
+ with torch.no_grad():
+     for x, l in tqdm(test_loader, total=len(test_loader)):
+         x = x.to(DEVICE)
+         l = l.to(DEVICE)
+
+         outputs = model(x)
+         probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
+         _, predicted = torch.max(probabilities, 1)
+
+         test_preds.extend(predicted.cpu().numpy())
+         test_labels.extend(l.cpu().numpy())
+         test_probs.extend(probabilities.cpu().numpy())
+
+ test_probs = np.array(test_probs)
+
+ df = pd.DataFrame({
+     'True Labels': test_labels,
+     'Predicted Labels': test_preds
+ })
+
+ for i in range(test_probs.shape[1]):
+     df[f'Prob_Class{i}'] = test_probs[:, i]
+
+ df.to_csv(f'{PROJECT_NAME}.csv', index=False)
+ print("Test labels, predictions, and probabilities have been saved.")
+
+ test_labels_binarized = label_binarize(test_labels, classes=[0, 1, 2])
+ test_preds_binarized = label_binarize(test_preds, classes=[0, 1, 2])
+
+ accuracy = accuracy_score(test_labels, test_preds)
+ f1 = f1_score(test_labels, test_preds, average='macro')
+ precision = precision_score(test_labels, test_preds, average='macro')
+ recall = recall_score(test_labels, test_preds, average='macro')
+ rocauc = roc_auc_score(test_labels_binarized, test_preds_binarized, average='macro')
+
+ print(f'Accuracy: {accuracy:.4f}')
+ print(f'F1 Score: {f1:.4f}')
+ print(f'Precision: {precision:.4f}')
+ print(f'Recall: {recall:.4f}')
+ print(f'ROC AUC: {rocauc:.4f}')
+
+ conf_matrix = confusion_matrix(test_labels, test_preds)
+ print("Confusion Matrix:")
+ print(conf_matrix)
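A quick sanity check, offered as a sketch rather than part of the committed script: after the freeze above and the classifier swap, only the five-layer MLP head should remain trainable. `model` here refers to the ViTForImageClassification instance built in NuSPIRe_partial_fine-tuning.py:

# Count trainable vs. total parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

# Every trainable parameter should live under the replacement MLP head,
# whose parameters are trainable by default when it is assigned.
for name, p in model.named_parameters():
    if p.requires_grad:
        assert name.startswith("classifier"), name

Passing the full model.parameters() to AdamW, as the script does, still works: frozen parameters never receive gradients, so the optimizer leaves them untouched.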
codes/Fine-tuning/expression_prediction/NuSPIRe_full_fine-tuning.ipynb ADDED
@@ -0,0 +1,1105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "b0410ca4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "import random\n",
13
+ "import numpy as np\n",
14
+ "import os\n",
15
+ "from torch.utils.tensorboard import SummaryWriter\n",
16
+ "import pandas as pd\n",
17
+ "from torchvision import transforms\n",
18
+ "from PIL import Image\n",
19
+ "from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler\n",
20
+ "from tqdm import tqdm\n",
21
+ "from transformers import ViTForImageClassification, ViTConfig\n"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "id": "d2f99710",
27
+ "metadata": {},
28
+ "source": [
29
+ "# hyperparameter"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 2,
35
+ "id": "b1a22094",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "SEED = 42\n",
40
+ "DEVICE = torch.device(\"cuda:0\")\n",
41
+ "DATA_DIR = '../train_nucleus_128_with_env_15dis_cell_scale/all/'\n",
42
+ "BATCH_SIZE = 300\n",
43
+ "NUM_EPOCHS = 30\n",
44
+ "PORJECT_NAME = f'Nuspire_mouse_brain_Regression'"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "924045aa",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "def set_seeds(seed_value=42, cuda_deterministic=False):\n",
55
+ " \"\"\"Set seeds for reproducibility.\"\"\"\n",
56
+ " random.seed(seed_value)\n",
57
+ " os.environ['PYTHONHASHSEED'] = str(seed_value)\n",
58
+ " np.random.seed(seed_value)\n",
59
+ " torch.manual_seed(seed_value)\n",
60
+ " if torch.cuda.is_available():\n",
61
+ " torch.cuda.manual_seed(seed_value)\n",
62
+ " torch.cuda.manual_seed_all(seed_value)\n",
63
+ " # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n",
64
+ " if cuda_deterministic: # slower, more reproducible\n",
65
+ " torch.backends.cudnn.deterministic = True\n",
66
+ " torch.backends.cudnn.benchmark = False\n",
67
+ " else: # faster, less reproducible\n",
68
+ " torch.backends.cudnn.deterministic = False\n",
69
+ " torch.backends.cudnn.benchmark = True\n"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "id": "9caab5e1",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "set_seeds(SEED)\n",
80
+ "timestamp = \"07\"\n",
81
+ "folder_name = f'./{PORJECT_NAME}_{timestamp}_checkpoint'\n",
82
+ "\n",
83
+ "if not os.path.exists(folder_name):\n",
84
+ " os.makedirs(folder_name)\n",
85
+ " # print(f\"'{folder_name}'has been created.\")\n",
86
+ "else:\n",
87
+ " print(f\"'{folder_name}' already exists.\")"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 5,
93
+ "id": "4a354850",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "class ImageDataset(Dataset):\n",
98
+ " def __init__(self, data_dir, transform=None):\n",
99
+ " self.data_dir = data_dir\n",
100
+ " self.transform = transform\n",
101
+ " self.file_list = os.listdir(data_dir)\n",
102
+ " self.cell_expression = pd.read_csv('../processed_data/cell_expression_filtered_size_allgene.csv', index_col=0)\n",
103
+ "\n",
104
+ " def __len__(self):\n",
105
+ " return len(self.file_list)\n",
106
+ "\n",
107
+ " def __getitem__(self, idx):\n",
108
+ " img_name = os.path.join(self.data_dir, self.file_list[idx])\n",
109
+ " img_index = img_name.split(\"/\")[-1].replace('image_', '').replace('.png', '')\n",
110
+ " image = Image.open(img_name).convert('L')\n",
111
+ " if self.transform:\n",
112
+ " image = self.transform(image)\n",
113
+ " \n",
114
+ " if img_index in self.cell_expression.index:\n",
115
+ " target = self.cell_expression.loc[img_index].values\n",
116
+ " else:\n",
117
+ " target = None\n",
118
+ " return image, target"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 6,
124
+ "id": "96d9003d",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "transform = transforms.Compose([\n",
129
+ " transforms.Resize((112, 112)),\n",
130
+ " transforms.RandomHorizontalFlip(p=0.5),\n",
131
+ " transforms.RandomVerticalFlip(p=0.5),\n",
132
+ " transforms.ToTensor(),\n",
133
+ " transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])\n",
134
+ "])"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 7,
140
+ "id": "772d8f7b",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "dataset = ImageDataset(DATA_DIR, transform=transform)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "id": "c52f7512",
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "total_size = len(dataset)\n",
155
+ "train_size = int(total_size * 0.8)\n",
156
+ "remaining_size = total_size - train_size\n",
157
+ "\n",
158
+ "valid_size = int(remaining_size * 0.5)\n",
159
+ "test_size = remaining_size - valid_size\n",
160
+ "\n",
161
+ "indices = list(range(total_size))\n",
162
+ "np.random.shuffle(indices)\n",
163
+ "\n",
164
+ "train_indices = indices[:train_size]\n",
165
+ "remaining_indices = indices[train_size:]\n",
166
+ "valid_indices = remaining_indices[:valid_size]\n",
167
+ "test_indices = remaining_indices[valid_size:]\n",
168
+ "\n",
169
+ "train_sampler = SubsetRandomSampler(train_indices)\n",
170
+ "valid_sampler = SubsetRandomSampler(valid_indices)\n",
171
+ "test_sampler = SubsetRandomSampler(test_indices)\n",
172
+ "\n",
173
+ "train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4)\n",
174
+ "valid_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler, num_workers=4)\n",
175
+ "test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)\n",
176
+ "\n",
177
+ "# print(train_indices)\n",
178
+ "# print(valid_indices)\n",
179
+ "# print(test_indices)"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "id": "c6e1da23",
185
+ "metadata": {},
186
+ "source": [
187
+ "# model"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 9,
193
+ "id": "224f2cab",
194
+ "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "name": "stderr",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "You are using a model of type vit_mae to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.\n",
201
+ "Some weights of ViTForImageClassification were not initialized from the model checkpoint at /mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
202
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
203
+ ]
204
+ }
205
+ ],
206
+ "source": [
207
+ "config = ViTConfig.from_pretrained(\"/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69\")\n",
208
+ "\n",
209
+ "config.hidden_dropout_prob = 0\n",
210
+ "config.attention_probs_dropout_prob = 0\n",
211
+ "config.num_labels = 347\n",
212
+ "\n",
213
+ "model = ViTForImageClassification.from_pretrained(\n",
214
+ " \"/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69\",\n",
215
+ " config=config\n",
216
+ ")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 10,
222
+ "id": "3704f75c",
223
+ "metadata": {},
224
+ "outputs": [
225
+ {
226
+ "data": {
227
+ "text/plain": [
228
+ "ViTForImageClassification(\n",
229
+ " (vit): ViTModel(\n",
230
+ " (embeddings): ViTEmbeddings(\n",
231
+ " (patch_embeddings): ViTPatchEmbeddings(\n",
232
+ " (projection): Conv2d(1, 768, kernel_size=(8, 8), stride=(8, 8))\n",
233
+ " )\n",
234
+ " (dropout): Dropout(p=0, inplace=False)\n",
235
+ " )\n",
236
+ " (encoder): ViTEncoder(\n",
237
+ " (layer): ModuleList(\n",
238
+ " (0-11): 12 x ViTLayer(\n",
239
+ " (attention): ViTAttention(\n",
240
+ " (attention): ViTSelfAttention(\n",
241
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
242
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
243
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
244
+ " (dropout): Dropout(p=0, inplace=False)\n",
245
+ " )\n",
246
+ " (output): ViTSelfOutput(\n",
247
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
248
+ " (dropout): Dropout(p=0, inplace=False)\n",
249
+ " )\n",
250
+ " )\n",
251
+ " (intermediate): ViTIntermediate(\n",
252
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
253
+ " (intermediate_act_fn): GELUActivation()\n",
254
+ " )\n",
255
+ " (output): ViTOutput(\n",
256
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
257
+ " (dropout): Dropout(p=0, inplace=False)\n",
258
+ " )\n",
259
+ " (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
260
+ " (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
261
+ " )\n",
262
+ " )\n",
263
+ " )\n",
264
+ " (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
265
+ " )\n",
266
+ " (classifier): Linear(in_features=768, out_features=347, bias=True)\n",
267
+ ")"
268
+ ]
269
+ },
270
+ "execution_count": 10,
271
+ "metadata": {},
272
+ "output_type": "execute_result"
273
+ }
274
+ ],
275
+ "source": [
276
+ "model.to(DEVICE)"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "id": "ea686b1b",
282
+ "metadata": {},
283
+ "source": [
284
+ "# Training"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": 11,
290
+ "id": "d9c8456e",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)\n",
295
+ "criterion = nn.MSELoss()"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 12,
301
+ "id": "18c5aee0",
302
+ "metadata": {},
303
+ "outputs": [
304
+ {
305
+ "name": "stdout",
306
+ "output_type": "stream",
307
+ "text": [
308
+ "Epoch: 1/30\n"
309
+ ]
310
+ },
311
+ {
312
+ "name": "stderr",
313
+ "output_type": "stream",
314
+ "text": [
315
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:57<00:00, 2.08s/it]\n",
316
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]\n"
317
+ ]
318
+ },
319
+ {
320
+ "name": "stdout",
321
+ "output_type": "stream",
322
+ "text": [
323
+ "Epoch 0, Train Loss: 0.1923, Validation Loss: 0.1672\n",
324
+ "Epoch: 2/30\n"
325
+ ]
326
+ },
327
+ {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:59<00:00, 2.10s/it]\n",
332
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]\n"
333
+ ]
334
+ },
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "Epoch 1, Train Loss: 0.1617, Validation Loss: 0.1588\n",
340
+ "Epoch: 3/30\n"
341
+ ]
342
+ },
343
+ {
344
+ "name": "stderr",
345
+ "output_type": "stream",
346
+ "text": [
347
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:59<00:00, 2.10s/it]\n",
348
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]\n"
349
+ ]
350
+ },
351
+ {
352
+ "name": "stdout",
353
+ "output_type": "stream",
354
+ "text": [
355
+ "Epoch 2, Train Loss: 0.1528, Validation Loss: 0.1526\n",
356
+ "Epoch: 4/30\n"
357
+ ]
358
+ },
359
+ {
360
+ "name": "stderr",
361
+ "output_type": "stream",
362
+ "text": [
363
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:06<00:00, 2.14s/it]\n",
364
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.33it/s]\n"
365
+ ]
366
+ },
367
+ {
368
+ "name": "stdout",
369
+ "output_type": "stream",
370
+ "text": [
371
+ "Epoch 3, Train Loss: 0.1480, Validation Loss: 0.1482\n",
372
+ "Epoch: 5/30\n"
373
+ ]
374
+ },
375
+ {
376
+ "name": "stderr",
377
+ "output_type": "stream",
378
+ "text": [
379
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:07<00:00, 2.15s/it]\n",
380
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.34it/s]\n"
381
+ ]
382
+ },
383
+ {
384
+ "name": "stdout",
385
+ "output_type": "stream",
386
+ "text": [
387
+ "Epoch 4, Train Loss: 0.1445, Validation Loss: 0.1473\n",
388
+ "Epoch: 6/30\n"
389
+ ]
390
+ },
391
+ {
392
+ "name": "stderr",
393
+ "output_type": "stream",
394
+ "text": [
395
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:02<00:00, 2.12s/it]\n",
396
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]\n"
397
+ ]
398
+ },
399
+ {
400
+ "name": "stdout",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "Epoch 5, Train Loss: 0.1421, Validation Loss: 0.1455\n",
404
+ "Epoch: 7/30\n"
405
+ ]
406
+ },
407
+ {
408
+ "name": "stderr",
409
+ "output_type": "stream",
410
+ "text": [
411
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [07:31<00:00, 3.16s/it]\n",
412
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]\n"
413
+ ]
414
+ },
415
+ {
416
+ "name": "stdout",
417
+ "output_type": "stream",
418
+ "text": [
419
+ "Epoch 6, Train Loss: 0.1400, Validation Loss: 0.1428\n",
420
+ "Epoch: 8/30\n"
421
+ ]
422
+ },
423
+ {
424
+ "name": "stderr",
425
+ "output_type": "stream",
426
+ "text": [
427
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:02<00:00, 2.12s/it]\n",
428
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]\n"
429
+ ]
430
+ },
431
+ {
432
+ "name": "stdout",
433
+ "output_type": "stream",
434
+ "text": [
435
+ "Epoch 7, Train Loss: 0.1380, Validation Loss: 0.1426\n",
436
+ "Epoch: 9/30\n"
437
+ ]
438
+ },
439
+ {
440
+ "name": "stderr",
441
+ "output_type": "stream",
442
+ "text": [
443
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:55<00:00, 2.06s/it]\n",
444
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:12<00:00, 1.41it/s]"
445
+ ]
446
+ },
447
+ {
448
+ "name": "stdout",
449
+ "output_type": "stream",
450
+ "text": [
451
+ "Epoch 8, Train Loss: 0.1366, Validation Loss: 0.1429\n",
452
+ "Epoch: 10/30\n"
453
+ ]
454
+ },
455
+ {
456
+ "name": "stderr",
457
+ "output_type": "stream",
458
+ "text": [
459
+ "\n",
460
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:53<00:00, 2.05s/it]\n",
461
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]\n"
462
+ ]
463
+ },
464
+ {
465
+ "name": "stdout",
466
+ "output_type": "stream",
467
+ "text": [
468
+ "Epoch 9, Train Loss: 0.1352, Validation Loss: 0.1408\n",
469
+ "Epoch: 11/30\n"
470
+ ]
471
+ },
472
+ {
473
+ "name": "stderr",
474
+ "output_type": "stream",
475
+ "text": [
476
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:57<00:00, 2.08s/it]\n",
477
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.33it/s]"
478
+ ]
479
+ },
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "Epoch 10, Train Loss: 0.1340, Validation Loss: 0.1419\n",
485
+ "Epoch: 12/30\n"
486
+ ]
487
+ },
488
+ {
489
+ "name": "stderr",
490
+ "output_type": "stream",
491
+ "text": [
492
+ "\n",
493
+ "100%|β–ˆοΏ½οΏ½β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:59<00:00, 2.09s/it]\n",
494
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]"
495
+ ]
496
+ },
497
+ {
498
+ "name": "stdout",
499
+ "output_type": "stream",
500
+ "text": [
501
+ "Epoch 11, Train Loss: 0.1325, Validation Loss: 0.1418\n",
502
+ "Epoch: 13/30\n"
503
+ ]
504
+ },
505
+ {
506
+ "name": "stderr",
507
+ "output_type": "stream",
508
+ "text": [
509
+ "\n",
510
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [04:59<00:00, 2.09s/it]\n",
511
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.38it/s]\n"
512
+ ]
513
+ },
514
+ {
515
+ "name": "stdout",
516
+ "output_type": "stream",
517
+ "text": [
518
+ "Epoch 12, Train Loss: 0.1314, Validation Loss: 0.1403\n",
519
+ "Epoch: 14/30\n"
520
+ ]
521
+ },
522
+ {
523
+ "name": "stderr",
524
+ "output_type": "stream",
525
+ "text": [
526
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:00<00:00, 2.10s/it]\n",
527
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]\n"
528
+ ]
529
+ },
530
+ {
531
+ "name": "stdout",
532
+ "output_type": "stream",
533
+ "text": [
534
+ "Epoch 13, Train Loss: 0.1304, Validation Loss: 0.1395\n",
535
+ "Epoch: 15/30\n"
536
+ ]
537
+ },
538
+ {
539
+ "name": "stderr",
540
+ "output_type": "stream",
541
+ "text": [
542
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:00<00:00, 2.10s/it]\n",
543
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]\n"
544
+ ]
545
+ },
546
+ {
547
+ "name": "stdout",
548
+ "output_type": "stream",
549
+ "text": [
550
+ "Epoch 14, Train Loss: 0.1289, Validation Loss: 0.1393\n",
551
+ "Epoch: 16/30\n"
552
+ ]
553
+ },
554
+ {
555
+ "name": "stderr",
556
+ "output_type": "stream",
557
+ "text": [
558
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:31<00:00, 2.32s/it]\n",
559
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.17it/s]"
560
+ ]
561
+ },
562
+ {
563
+ "name": "stdout",
564
+ "output_type": "stream",
565
+ "text": [
566
+ "Epoch 15, Train Loss: 0.1279, Validation Loss: 0.1398\n",
567
+ "Epoch: 17/30\n"
568
+ ]
569
+ },
570
+ {
571
+ "name": "stderr",
572
+ "output_type": "stream",
573
+ "text": [
574
+ "\n",
575
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:02<00:00, 2.53s/it]\n",
576
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.16it/s]"
577
+ ]
578
+ },
579
+ {
580
+ "name": "stdout",
581
+ "output_type": "stream",
582
+ "text": [
583
+ "Epoch 16, Train Loss: 0.1265, Validation Loss: 0.1394\n",
584
+ "Epoch: 18/30\n"
585
+ ]
586
+ },
587
+ {
588
+ "name": "stderr",
589
+ "output_type": "stream",
590
+ "text": [
591
+ "\n",
592
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:01<00:00, 2.53s/it]\n",
593
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.15it/s]"
594
+ ]
595
+ },
596
+ {
597
+ "name": "stdout",
598
+ "output_type": "stream",
599
+ "text": [
600
+ "Epoch 17, Train Loss: 0.1254, Validation Loss: 0.1403\n",
601
+ "Epoch: 19/30\n"
602
+ ]
603
+ },
604
+ {
605
+ "name": "stderr",
606
+ "output_type": "stream",
607
+ "text": [
608
+ "\n",
609
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:03<00:00, 2.54s/it]\n",
610
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.17it/s]\n"
611
+ ]
612
+ },
613
+ {
614
+ "name": "stdout",
615
+ "output_type": "stream",
616
+ "text": [
617
+ "Epoch 18, Train Loss: 0.1241, Validation Loss: 0.1389\n",
618
+ "Epoch: 20/30\n"
619
+ ]
620
+ },
621
+ {
622
+ "name": "stderr",
623
+ "output_type": "stream",
624
+ "text": [
625
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:07<00:00, 2.57s/it]\n",
626
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.13it/s]"
627
+ ]
628
+ },
629
+ {
630
+ "name": "stdout",
631
+ "output_type": "stream",
632
+ "text": [
633
+ "Epoch 19, Train Loss: 0.1234, Validation Loss: 0.1399\n",
634
+ "Epoch: 21/30\n"
635
+ ]
636
+ },
637
+ {
638
+ "name": "stderr",
639
+ "output_type": "stream",
640
+ "text": [
641
+ "\n",
642
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:13<00:00, 2.61s/it]\n",
643
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:16<00:00, 1.10it/s]"
644
+ ]
645
+ },
646
+ {
647
+ "name": "stdout",
648
+ "output_type": "stream",
649
+ "text": [
650
+ "Epoch 20, Train Loss: 0.1225, Validation Loss: 0.1407\n",
651
+ "Epoch: 22/30\n"
652
+ ]
653
+ },
654
+ {
655
+ "name": "stderr",
656
+ "output_type": "stream",
657
+ "text": [
658
+ "\n",
659
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:13<00:00, 2.62s/it]\n",
660
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.14it/s]"
661
+ ]
662
+ },
663
+ {
664
+ "name": "stdout",
665
+ "output_type": "stream",
666
+ "text": [
667
+ "Epoch 21, Train Loss: 0.1213, Validation Loss: 0.1395\n",
668
+ "Epoch: 23/30\n"
669
+ ]
670
+ },
671
+ {
672
+ "name": "stderr",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "\n",
676
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:15<00:00, 2.62s/it]\n",
677
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:16<00:00, 1.12it/s]"
678
+ ]
679
+ },
680
+ {
681
+ "name": "stdout",
682
+ "output_type": "stream",
683
+ "text": [
684
+ "Epoch 22, Train Loss: 0.1203, Validation Loss: 0.1399\n",
685
+ "Epoch: 24/30\n"
686
+ ]
687
+ },
688
+ {
689
+ "name": "stderr",
690
+ "output_type": "stream",
691
+ "text": [
692
+ "\n",
693
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:14<00:00, 2.62s/it]\n",
694
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.14it/s]"
695
+ ]
696
+ },
697
+ {
698
+ "name": "stdout",
699
+ "output_type": "stream",
700
+ "text": [
701
+ "Epoch 23, Train Loss: 0.1190, Validation Loss: 0.1400\n",
702
+ "Epoch: 25/30\n"
703
+ ]
704
+ },
705
+ {
706
+ "name": "stderr",
707
+ "output_type": "stream",
708
+ "text": [
709
+ "\n",
710
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:13<00:00, 2.61s/it]\n",
711
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.14it/s]"
712
+ ]
713
+ },
714
+ {
715
+ "name": "stdout",
716
+ "output_type": "stream",
717
+ "text": [
718
+ "Epoch 24, Train Loss: 0.1182, Validation Loss: 0.1413\n",
719
+ "Epoch: 26/30\n"
720
+ ]
721
+ },
722
+ {
723
+ "name": "stderr",
724
+ "output_type": "stream",
725
+ "text": [
726
+ "\n",
727
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:39<00:00, 2.37s/it]\n",
728
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.34it/s]"
729
+ ]
730
+ },
731
+ {
732
+ "name": "stdout",
733
+ "output_type": "stream",
734
+ "text": [
735
+ "Epoch 25, Train Loss: 0.1173, Validation Loss: 0.1407\n",
736
+ "Epoch: 27/30\n"
737
+ ]
738
+ },
739
+ {
740
+ "name": "stderr",
741
+ "output_type": "stream",
742
+ "text": [
743
+ "\n",
744
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [05:08<00:00, 2.16s/it]\n",
745
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:13<00:00, 1.37it/s]"
746
+ ]
747
+ },
748
+ {
749
+ "name": "stdout",
750
+ "output_type": "stream",
751
+ "text": [
752
+ "Epoch 26, Train Loss: 0.1163, Validation Loss: 0.1411\n",
753
+ "Epoch: 28/30\n"
754
+ ]
755
+ },
756
+ {
757
+ "name": "stderr",
758
+ "output_type": "stream",
759
+ "text": [
760
+ "\n",
761
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [10:22<00:00, 4.35s/it]\n",
762
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:32<00:00, 1.78s/it]"
763
+ ]
764
+ },
765
+ {
766
+ "name": "stdout",
767
+ "output_type": "stream",
768
+ "text": [
769
+ "Epoch 27, Train Loss: 0.1156, Validation Loss: 0.1412\n",
770
+ "Epoch: 29/30\n"
771
+ ]
772
+ },
773
+ {
774
+ "name": "stderr",
775
+ "output_type": "stream",
776
+ "text": [
777
+ "\n",
778
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [09:45<00:00, 4.09s/it]\n",
779
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:16<00:00, 1.09it/s]"
780
+ ]
781
+ },
782
+ {
783
+ "name": "stdout",
784
+ "output_type": "stream",
785
+ "text": [
786
+ "Epoch 28, Train Loss: 0.1144, Validation Loss: 0.1416\n",
787
+ "Epoch: 30/30\n"
788
+ ]
789
+ },
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "\n",
795
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 143/143 [06:34<00:00, 2.76s/it]\n",
796
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.14it/s]"
797
+ ]
798
+ },
799
+ {
800
+ "name": "stdout",
801
+ "output_type": "stream",
802
+ "text": [
803
+ "Epoch 29, Train Loss: 0.1134, Validation Loss: 0.1412\n"
804
+ ]
805
+ },
806
+ {
807
+ "name": "stderr",
808
+ "output_type": "stream",
809
+ "text": [
810
+ "\n"
811
+ ]
812
+ }
813
+ ],
814
+ "source": [
815
+ "writer = SummaryWriter(f'./tensorboard/{PORJECT_NAME}_{timestamp}')\n",
816
+ "step1 = 0\n",
817
+ "step2 = 0\n",
818
+ "best_val_loss = 1\n",
819
+ "\n",
820
+ "for epoch in range(NUM_EPOCHS):\n",
821
+ " print(f\"Epoch: {epoch+1}/{NUM_EPOCHS}\")\n",
822
+ " model.train()\n",
823
+ " loss_list = []\n",
824
+ " for i, (x,l) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
825
+ " x = x.to(DEVICE)\n",
826
+ " l = l.to(DEVICE)\n",
827
+ " \n",
828
+ " optimizer.zero_grad()\n",
829
+ " \n",
830
+ " outputs = model(x)\n",
831
+ " \n",
832
+ " loss = criterion(outputs.logits, l.float())\n",
833
+ " \n",
834
+ " writer.add_scalar(\"Step/Train Loss\", loss.item(),step1)\n",
835
+ " loss_list.append(loss.item())\n",
836
+ " \n",
837
+ " step1+=1\n",
838
+ " loss.backward()\n",
839
+ " optimizer.step()\n",
840
+ " train_loss = np.mean(loss_list)\n",
841
+ "\n",
842
+ " model.eval()\n",
843
+ " loss_list = []\n",
844
+ " with torch.no_grad():\n",
845
+ " for i, (x,l) in tqdm(enumerate(valid_loader), total=len(valid_loader)):\n",
846
+ " x = x.to(DEVICE)\n",
847
+ " l = l.to(DEVICE)\n",
848
+ "\n",
849
+ " optimizer.zero_grad()\n",
850
+ "\n",
851
+ " outputs = model(x)\n",
852
+ "\n",
853
+ " loss = criterion(outputs.logits, l.float())\n",
854
+ " \n",
855
+ " writer.add_scalar(\"Step/Validation Loss\", loss.item(),step2)\n",
856
+ "\n",
857
+ " loss_list.append(loss.item())\n",
858
+ " step2+=1\n",
859
+ " val_loss = np.mean(loss_list)\n",
860
+ " \n",
861
+ " # Save the model if the validation loss is the best we've seen so far.\n",
862
+ " if val_loss < best_val_loss:\n",
863
+ " torch.save(model.state_dict(), f'{folder_name}/{PORJECT_NAME}_best_model.pt')\n",
864
+ " model.save_pretrained(f'{folder_name}/{PORJECT_NAME}_best_model')\n",
865
+ " best_epoch=epoch\n",
866
+ " best_val_loss = val_loss\n",
867
+ "\n",
868
+ " lr = optimizer.param_groups[0]['lr']\n",
869
+ " writer.add_scalar(\"Epoch/Lr\", lr, epoch)\n",
870
+ " writer.add_scalars(\"Epoch/Loss\",{'Train Loss':train_loss,'Validation Loss':val_loss},epoch)\n",
871
+ " print(f\"Epoch {epoch}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}\")\n"
872
+ ]
873
+ },
874
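+ {
+ "cell_type": "markdown",
+ "id": "early-stopping-note",
+ "metadata": {},
+ "source": [
+ "**Editor's note (not part of the original run):** the log above keeps training for all 30 epochs even though the best validation loss (0.1389) is reached around epoch 19. Below is a minimal patience-based early-stopping sketch; `should_stop`, `patience`, and `min_delta` are illustrative names, not variables defined elsewhere in this notebook. The demo list reuses the validation losses printed above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "early-stopping-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def should_stop(val_history, patience=5, min_delta=0.0):\n",
+ "    \"\"\"True once the last `patience` epochs failed to improve on the best loss.\"\"\"\n",
+ "    if len(val_history) <= patience:\n",
+ "        return False\n",
+ "    best_before = min(val_history[:-patience])\n",
+ "    return min(val_history[-patience:]) > best_before - min_delta\n",
+ "\n",
+ "# Validation losses from epochs 9-24 of the run above\n",
+ "history = [0.1408, 0.1419, 0.1418, 0.1403, 0.1395, 0.1393, 0.1398, 0.1394,\n",
+ "           0.1403, 0.1389, 0.1399, 0.1407, 0.1395, 0.1399, 0.1400, 0.1413]\n",
+ "for n in range(1, len(history) + 1):\n",
+ "    if should_stop(history[:n], patience=5):\n",
+ "        print(f'would stop after {n} recorded epochs')  # i.e. epoch 23 of the run above\n",
+ "        break\n"
+ ]
+ },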
+ {
+ "cell_type": "markdown",
+ "id": "0af358d9",
+ "metadata": {},
+ "source": [
+ "# Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "c7a6c5db",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transform = transforms.Compose([\n",
+ "    transforms.Resize((112, 112)),\n",
+ "    # transforms.RandomHorizontalFlip(p=0.5),\n",
+ "    # transforms.RandomVerticalFlip(p=0.5),\n",
+ "    transforms.ToTensor(),\n",
+ "    transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "aa667594",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = ImageDataset(DATA_DIR, transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "85b09b94",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)"
+ ]
+ },
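+ {
+ "cell_type": "markdown",
+ "id": "test-order-note",
+ "metadata": {},
+ "source": [
+ "**Editor's note (not part of the original run):** if `test_sampler` draws indices in random order (e.g. a `SubsetRandomSampler`), the prediction arrays saved later will not line up with the dataset order across runs. A deterministic alternative is sketched here; `test_indices` is a hypothetical list holding the same indices the sampler was built from."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "test-order-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch only: iterate the test split in a fixed, reproducible order.\n",
+ "from torch.utils.data import Subset\n",
+ "\n",
+ "test_subset = Subset(dataset, test_indices)  # test_indices: hypothetical index list\n",
+ "test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)\n"
+ ]
+ },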
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "2e43c973",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ViTForImageClassification(\n",
+ "  (vit): ViTModel(\n",
+ "    (embeddings): ViTEmbeddings(\n",
+ "      (patch_embeddings): ViTPatchEmbeddings(\n",
+ "        (projection): Conv2d(1, 768, kernel_size=(8, 8), stride=(8, 8))\n",
+ "      )\n",
+ "      (dropout): Dropout(p=0, inplace=False)\n",
+ "    )\n",
+ "    (encoder): ViTEncoder(\n",
+ "      (layer): ModuleList(\n",
+ "        (0-11): 12 x ViTLayer(\n",
+ "          (attention): ViTAttention(\n",
+ "            (attention): ViTSelfAttention(\n",
+ "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ "              (dropout): Dropout(p=0, inplace=False)\n",
+ "            )\n",
+ "            (output): ViTSelfOutput(\n",
+ "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ "              (dropout): Dropout(p=0, inplace=False)\n",
+ "            )\n",
+ "          )\n",
+ "          (intermediate): ViTIntermediate(\n",
+ "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ "            (intermediate_act_fn): GELUActivation()\n",
+ "          )\n",
+ "          (output): ViTOutput(\n",
+ "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ "            (dropout): Dropout(p=0, inplace=False)\n",
+ "          )\n",
+ "          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ "          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ "        )\n",
+ "      )\n",
+ "    )\n",
+ "    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ "  )\n",
+ "  (classifier): Linear(in_features=768, out_features=347, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_path = f'{folder_name}/{PORJECT_NAME}_best_model.pt'\n",
+ "model.load_state_dict(torch.load(model_path))\n",
+ "model.to(DEVICE) "
+ ]
+ },
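+ {
+ "cell_type": "markdown",
+ "id": "map-location-note",
+ "metadata": {},
+ "source": [
+ "**Editor's note (not part of the original run):** `torch.load` restores tensors to the device they were saved from, so the cell above can fail on a CPU-only machine or when the GPU index differs from the training run. A more portable sketch passes `map_location`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "map-location-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch only: remap the checkpoint onto the current DEVICE at load time.\n",
+ "state_dict = torch.load(model_path, map_location=DEVICE)\n",
+ "model.load_state_dict(state_dict)\n",
+ "model.to(DEVICE)\n"
+ ]
+ },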
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "08a01b05",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:15<00:00, 1.14it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "true_labels = []\n",
+ "predicted_outputs = []\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    for i, (x, l) in tqdm(enumerate(test_loader), total=len(test_loader)):\n",
+ "        x = x.to(DEVICE)\n",
+ "        l = l.to(DEVICE)\n",
+ "\n",
+ "        outputs = model(x)\n",
+ "\n",
+ "        # Collect true labels and predicted outputs\n",
+ "        true_labels.append(l.cpu())\n",
+ "        predicted_outputs.append(outputs.logits.cpu())\n",
+ "\n",
+ "    true_labels = torch.cat(true_labels).numpy()\n",
+ "    predicted_outputs = torch.cat(predicted_outputs).numpy() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "10fdc2f5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.save(f'{PORJECT_NAME}_{timestamp}_all_outputs.npy', predicted_outputs)\n",
+ "np.save(f'{PORJECT_NAME}_{timestamp}_all_targets.npy', true_labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "ee0c734a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MSE: [0.13185952 0.09007648 0.16135108 ... 0.12363786 0.18178999 0.25879715]\n",
+ "Pearson: [0.77104155 0.82377504 0.70091003 ... 0.77836889 0.64078275 0.49941442]\n",
+ "MSE - Mean: 0.1386, Std: 0.0514\n",
+ "Pearson - Mean: 0.7380, Std: 0.1021\n"
+ ]
+ }
+ ],
+ "source": [
+ "from scipy.stats import pearsonr\n",
+ "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error, explained_variance_score\n",
+ "\n",
+ "n_samples, n_features = true_labels.shape\n",
+ "\n",
+ "results = {metric: [] for metric in ['MSE',\n",
+ "                                     # 'RMSE',\n",
+ "                                     # 'MAE',\n",
+ "                                     # 'MAPE',\n",
+ "                                     # 'R_squared',\n",
+ "                                     # 'Explained_Variance',\n",
+ "                                     'Pearson']}\n",
+ "\n",
+ "for i in range(n_samples):\n",
+ "    mse = mean_squared_error(true_labels[i, :], predicted_outputs[i, :])\n",
+ "    # rmse = np.sqrt(mse)\n",
+ "    # mae = mean_absolute_error(true_labels[i, :], predicted_outputs[i, :])\n",
+ "    # mape = mean_absolute_percentage_error(true_labels[i, :], predicted_outputs[i, :])\n",
+ "    # r2 = r2_score(true_labels[i, :], predicted_outputs[i, :])\n",
+ "    # explained_var = explained_variance_score(true_labels[i, :], predicted_outputs[i, :])\n",
+ "    pcc, _ = pearsonr(true_labels[i, :], predicted_outputs[i, :])\n",
+ "\n",
+ "    results['MSE'].append(mse)\n",
+ "    # results['RMSE'].append(rmse)\n",
+ "    # results['MAE'].append(mae)\n",
+ "    # results['MAPE'].append(mape)\n",
+ "    # results['R_squared'].append(r2)\n",
+ "    # results['Explained_Variance'].append(explained_var)\n",
+ "    results['Pearson'].append(pcc)\n",
+ "\n",
+ "for metric in results:\n",
+ "    results[metric] = np.array(results[metric])\n",
+ "\n",
+ "for metric in results:\n",
+ "    print(f\"{metric}: {results[metric]}\")\n",
+ "\n",
+ "for metric in results:\n",
+ "    print(f\"{metric} - Mean: {np.mean(results[metric]):.4f}, Std: {np.std(results[metric]):.4f}\")"
+ ]
+ },
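+ {
+ "cell_type": "markdown",
+ "id": "vectorized-metrics-note",
+ "metadata": {},
+ "source": [
+ "**Editor's note (not part of the original run):** the per-sample loop above can be vectorized. Row-wise MSE is the mean of squared differences along axis 1, and row-wise Pearson correlation is the dot product of mean-centered, L2-normalized rows. A sketch, assuming `true_labels` and `predicted_outputs` as computed earlier:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "vectorized-metrics-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Should reproduce the loop-based means above (MSE ~0.1386, Pearson ~0.7380).\n",
+ "mse_vec = ((true_labels - predicted_outputs) ** 2).mean(axis=1)\n",
+ "\n",
+ "def center_norm(a):\n",
+ "    a = a - a.mean(axis=1, keepdims=True)\n",
+ "    return a / np.linalg.norm(a, axis=1, keepdims=True)\n",
+ "\n",
+ "pearson_vec = (center_norm(true_labels) * center_norm(predicted_outputs)).sum(axis=1)\n",
+ "print(f'MSE - Mean: {mse_vec.mean():.4f}, Pearson - Mean: {pearson_vec.mean():.4f}')\n"
+ ]
+ }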
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
pretraining_pl_DDP_v5.py β†’ codes/Pre-training/pretraining.py RENAMED
File without changes