cconsti committed on
Commit
bd39841
·
verified ·
1 Parent(s): 2ec6283

Upload 5 files

Browse files
AI Cancer Cell Types App.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
5
+ from pytorch_lightning.loggers import TensorBoardLogger
6
+ import pandas as pd
7
+ from sklearn.model_selection import train_test_split
8
+ from torch.utils.data import DataLoader
9
+ from datasets import TrainMicrographDataset, ValidationMicrographDataset, InferenceMicrographDataset
10
+ from model import MicrographCleaner, find_best_model, find_optimal_threshold, prepare_submission
11
+ import gradio as gr
12
+
13
# Kaggle setup: install the API credentials and fetch the competition data.
# These shell steps are best-effort (their exit statuses are not inspected);
# the explicit file checks below decide whether the data is actually usable.
os.system("mkdir -p ~/.kaggle")
os.system("cp kaggle.json ~/.kaggle/")
os.system("chmod 600 ~/.kaggle/kaggle.json")
os.system("kaggle competitions download -c micrographs-competition")
os.system("unzip -n micrographs-competition.zip")

# Verify data. Raise explicitly rather than using `assert`: assert statements
# are removed when Python runs with -O, which would silently skip the check.
if not os.path.isfile("train.csv"):
    raise FileNotFoundError("Error, train.csv not found")
if not os.path.isfile("test.csv"):
    raise FileNotFoundError("Error, test.csv not found")

# Hyperparameters shared by the datasets, data loaders and trainer.
WINDOW_SIZE = 512  # side length passed to the datasets as window_size
BATCH_SIZE = 8
N_EPOCHS = 20
28
+
29
def train_and_generate_submission():
    """Train the model, pick a threshold, and write ``submission.csv``.

    Reads ``train.csv`` and ``test.csv`` from the working directory, holds out
    20% of the training rows for validation, fits a ``MicrographCleaner`` with
    checkpointing and early stopping on ``val_iou``, reloads the best
    checkpoint, selects a binarisation threshold on the validation loader and
    saves the test-set predictions.

    Returns:
        A status message shown in the Gradio text output.

    Raises:
        RuntimeError: If training finished without any checkpoint being saved.
    """
    # Load data
    train_df = pd.read_csv('train.csv')
    test_df = pd.read_csv('test.csv')

    # Fixed random_state keeps the train/validation split reproducible.
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

    train_dataset = TrainMicrographDataset(train_df, WINDOW_SIZE)
    val_dataset = ValidationMicrographDataset(val_df, WINDOW_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2)

    # Initialize model
    model = MicrographCleaner(n_hidden_layers=12, n_kernels=24, kernel_size=5)

    # Logger and callbacks
    logger = TensorBoardLogger('lightning_logs', name='micrograph_cleaner')
    checkpoint_callback = ModelCheckpoint(monitor='val_iou', mode='max', dirpath='checkpoints', filename='best-model')
    early_stop_callback = EarlyStopping(monitor='val_iou', patience=5, mode='max')

    # Trainer
    trainer = pl.Trainer(max_epochs=N_EPOCHS, accelerator='auto', devices=1,
                         logger=logger, callbacks=[checkpoint_callback, early_stop_callback])

    trainer.fit(model, train_loader, val_loader)

    # Ask the checkpoint callback which file it saved. The callback is
    # configured with filename='best-model', so scanning the directory for
    # names of the form "micrograph-epoch=...-val_loss=..." would find nothing.
    best_model_path = checkpoint_callback.best_model_path
    if not best_model_path:
        raise RuntimeError("Training finished without saving a checkpoint")
    best_model = MicrographCleaner.load_from_checkpoint(best_model_path)
    best_model.eval()

    # Choose the binarisation threshold on the validation crops.
    optimal_threshold = find_optimal_threshold(best_model, val_loader)

    # prepare_submission accepts (model, test_df, window_size, threshold);
    # passing any other keyword raises TypeError, so only those are supplied.
    submission_df = prepare_submission(
        best_model, test_df,
        window_size=WINDOW_SIZE,
        threshold=optimal_threshold,
    )

    submission_df.to_csv('submission.csv', index=False)

    return "submission.csv generated successfully!"
77
+
78
# Gradio front end: no input widgets, one text output. Submitting the form
# runs the complete train-and-predict pipeline and shows its status message.
iface = gr.Interface(
    train_and_generate_submission,
    None,
    "text",
    title="Micrograph Model Trainer & Submission",
    description="Click submit to explicitly train your model and generate submission.csv",
)

# Start the web server for the interface.
iface.launch()
AI Cancer Cell Types Dataset.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import zlib
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torchvision.transforms.v2 as transforms
8
+
9
def decode_array(encoded_base64_str):
    """Decode a base64 string of zlib-compressed NumPy data into an array.

    The input is base64-decoded, inflated with zlib, and the resulting bytes
    are handed to ``np.load`` through an in-memory buffer.
    """
    inflated = zlib.decompress(base64.b64decode(encoded_base64_str))
    buffer = io.BytesIO(inflated)
    return np.load(buffer)
13
+
14
class BaseMicrographDataset(Dataset):
    """Shared loading, normalisation and padding helpers for micrograph data.

    Subclasses provide ``__getitem__``; this base class stores the data frame
    and window size and reports one item per data-frame row.
    """

    def __init__(self, df, window_size: int):
        self.df = df
        self.window_size = window_size

    def __len__(self):
        return len(self.df)

    def load_and_normalize_image(self, encoded_image: str):
        """Decode an image and rescale it to [0, 1] by its 2nd/98th percentiles.

        Returns:
            A float32 tensor; a 2-D image gains a leading axis of size 1.
        """
        image = decode_array(encoded_image).astype(np.float32)
        p_low, p_high = np.percentile(image, [2, 98])
        # The epsilon keeps the division finite when the image is constant.
        image = np.clip((image - p_low) / (p_high - p_low + 1e-8), 0, 1)
        if image.ndim == 2:
            image = image[np.newaxis, ...]
        return torch.from_numpy(image)

    def load_mask(self, encoded_mask: str):
        """Decode a mask as float32; a 2-D mask gains a leading axis of size 1."""
        mask = decode_array(encoded_mask).astype(np.float32)
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]
        return torch.from_numpy(mask)

    def pad_to_min_size(self, image: torch.Tensor, min_size: int):
        """Pad the bottom and right of a 3-D tensor up to ``min_size``.

        Padding uses reflection. A single reflect pad must be smaller than the
        dimension it extends, so large gaps are filled in several reflect
        steps; an axis of length 1 (which cannot be reflected) is extended by
        edge replication instead.

        Args:
            image: Tensor whose last two axes are height and width.
            min_size: Minimum height and width of the result.

        Returns:
            ``(padded, (pad_h, pad_w))`` where the pads are the total number
            of rows and columns appended.
        """
        _, h, w = image.shape
        pad_h = max(0, min_size - h)
        pad_w = max(0, min_size - w)
        pad = torch.nn.functional.pad

        # Use a batched (1, C, H, W) tensor so the 2-D padding modes apply.
        padded = image.unsqueeze(0)
        left_h, left_w = pad_h, pad_w
        while left_h > 0 or left_w > 0:
            cur_h, cur_w = padded.shape[-2:]
            step_h = min(left_h, max(cur_h - 1, 0))
            step_w = min(left_w, max(cur_w - 1, 0))
            if step_h == 0 and step_w == 0:
                # Nothing left that reflection can extend: replicate the edge.
                padded = pad(padded, (0, left_w, 0, left_h), mode="replicate")
                break
            padded = pad(padded, (0, step_w, 0, step_h), mode="reflect")
            left_h -= step_h
            left_w -= step_w
        return padded.squeeze(0), (pad_h, pad_w)
42
+
43
class TrainMicrographDataset(BaseMicrographDataset):
    """Training dataset yielding one randomly augmented (image, mask) crop per row."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)

        # Geometric augmentations, applied to image and mask stacked as one
        # tensor so both channels receive the same random crop/flip/affine.
        geometric_steps = [
            transforms.RandomCrop(window_size),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=45, translate=(0.15, 0.15), scale=(0.85, 1.15), fill=0),
        ]
        self.shared_transform = transforms.Compose(geometric_steps)

        def maybe_add_noise(x):
            # With probability 0.3, add Gaussian noise of standard deviation 0.05.
            if np.random.random() < 0.3:
                return x + torch.randn_like(x) * 0.05
            return x

        # NOTE(review): this pipeline is built here but __getitem__ below never
        # applies it — confirm whether image-only augmentation was intended.
        photometric_steps = [
            transforms.GaussianBlur(7, sigma=(0.1, 2.0)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
            transforms.Lambda(maybe_add_noise),
        ]
        self.image_only_transforms = transforms.Compose(photometric_steps)

    def __getitem__(self, idx):
        """Return the augmented ``(image, mask)`` pair for row ``idx``."""
        record = self.df.iloc[idx]
        image, _ = self.pad_to_min_size(
            self.load_and_normalize_image(record['image']), self.window_size
        )
        mask, _ = self.pad_to_min_size(
            self.load_mask(record['mask']), self.window_size
        )
        # Channel 0 is the image and channel 1 the mask while transforming.
        pair = self.shared_transform(torch.cat([image, mask], dim=0))
        image, mask = torch.split(pair, [1, 1], dim=0)
        return image, mask
68
+
69
class ValidationMicrographDataset(BaseMicrographDataset):
    """Validation dataset yielding five fixed crops (four corners and centre) per row."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        self.n_crops = 5

    def __len__(self):
        # Every data-frame row contributes n_crops items.
        return self.n_crops * len(self.df)

    def get_crop_coordinates(self, image_shape, crop_idx):
        """Return the ``(h_start, w_start)`` corner of crop number ``crop_idx``.

        Index 4 is the centred crop; below that, indices under 2 start at the
        top edge and even indices start at the left edge.
        """
        h, w = image_shape
        max_top = h - self.window_size
        max_left = w - self.window_size
        if crop_idx == 4:
            return max_top // 2, max_left // 2
        top = max_top if crop_idx >= 2 else 0
        left = max_left if crop_idx % 2 else 0
        return top, left

    def crop_tensors(self, image, mask, h_start, w_start):
        """Cut the same ``window_size`` square out of both image and mask."""
        rows = slice(h_start, h_start + self.window_size)
        cols = slice(w_start, w_start + self.window_size)
        return (image[:, rows, cols], mask[:, rows, cols])

    def __getitem__(self, idx):
        """Return crop ``idx % n_crops`` of row ``idx // n_crops``."""
        image_idx, crop_idx = divmod(idx, self.n_crops)
        record = self.df.iloc[image_idx]
        image, _ = self.pad_to_min_size(
            self.load_and_normalize_image(record['image']), self.window_size
        )
        mask, _ = self.pad_to_min_size(
            self.load_mask(record['mask']), self.window_size
        )
        top, left = self.get_crop_coordinates(image.shape[1:], crop_idx)
        return self.crop_tensors(image, mask, top, left)
103
+
104
class InferenceMicrographDataset(BaseMicrographDataset):
    """Inference dataset yielding a whole padded image with its id and padding."""

    def __getitem__(self, idx):
        """Return ``(image, Id, (pad_h, pad_w))`` for row ``idx``."""
        record = self.df.iloc[idx]
        normalized = self.load_and_normalize_image(record['image'])
        # The padding amounts are returned so callers can crop them off again.
        image, padding = self.pad_to_min_size(normalized, self.window_size)
        return image, record['Id'], padding
AI Cancer Cell Types Kaggle.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"username":"<KAGGLE_USERNAME>","key":"<KAGGLE_API_KEY>"}
AI Cancer Cell Types Model.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytorch_lightning as pl
4
+ import numpy as np
5
+ import tqdm
6
+ import os
7
+ import re
8
+ from pathlib import Path
9
+ import pandas as pd
10
+ from torch.utils.data import DataLoader
11
+ from datasets import ValidationMicrographDataset, InferenceMicrographDataset
12
+ from scipy import ndimage
13
+ from utils import encode_array, decode_array
14
+
15
+ # SimpleCNN explicitly defined
16
class SimpleCNN(nn.Module):
    """Fully convolutional network mapping a 1-channel image to a 1-channel map.

    The stack is an input block, ``n_hidden_layers`` hidden blocks (each
    Conv2d with 'same' padding, GroupNorm with 4 groups, PReLU), then a 1x1
    convolution followed by a sigmoid.
    """

    def __init__(self, n_hidden_layers, n_kernels, kernel_size):
        super().__init__()

        def conv_block(in_channels):
            # One convolution -> normalisation -> activation unit.
            return [
                nn.Conv2d(in_channels, n_kernels, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(4, n_kernels),
                nn.PReLU(),
            ]

        stack = conv_block(1)
        for _ in range(n_hidden_layers):
            stack += conv_block(n_kernels)
        # Output head: collapse to a single channel and squash into (0, 1).
        stack += [nn.Conv2d(n_kernels, 1, kernel_size=1), nn.Sigmoid()]

        self.conv_layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the convolutional stack."""
        return self.conv_layers(x)
39
+
40
+ # Lightning module wrapper explicitly defined
41
class MicrographCleaner(pl.LightningModule):
    """Lightning wrapper around ``SimpleCNN`` for binary mask prediction.

    Training and validation minimise a weighted mix of binary cross-entropy,
    Dice loss and focal loss; validation also logs the IoU as ``val_iou``.
    """

    def __init__(self, n_hidden_layers=12, n_kernels=24, kernel_size=5, learning_rate=0.001):
        super().__init__()
        # Store the constructor arguments in the checkpoint so that
        # load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()
        self.model = SimpleCNN(n_hidden_layers, n_kernels, kernel_size)
        self.lossF = nn.BCELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the network output for ``x`` (sigmoid probabilities)."""
        return self.model(x)

    def dice_loss(self, pred, target):
        """Return ``1 - Dice`` over all elements, smoothed by 1.0."""
        smooth = 1.0
        # reshape (not view) so non-contiguous inputs are accepted too.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)
        intersection = (pred_flat * target_flat).sum()
        union = pred_flat.sum() + target_flat.sum()
        dice = (2.0 * intersection + smooth) / (union + smooth)
        return 1.0 - dice

    def focal_loss(self, pred, target, alpha=0.8, gamma=2.0):
        """Return the mean alpha-balanced focal loss.

        Args:
            pred: Predicted probabilities in [0, 1].
            target: Targets with the same shape as ``pred``.
            alpha: Weight of positive elements; negatives get ``1 - alpha``.
            gamma: Focusing exponent applied to ``1 - pt``.
        """
        # The cross-entropy must stay per-element: the focal and alpha weights
        # are per-element, and scaling an already averaged BCE by them would
        # not down-weight the easy elements individually.
        bce = nn.functional.binary_cross_entropy(pred, target, reduction='none')
        pt = target * pred + (1 - target) * (1 - pred)
        focal_weight = (1 - pt) ** gamma
        alpha_weight = target * alpha + (1 - target) * (1 - alpha)
        return (focal_weight * alpha_weight * bce).mean()

    def iou_score(self, pred, target, threshold=0.5):
        """Return the IoU of ``pred`` and ``target`` binarised at ``threshold``."""
        pred_binary = (pred > threshold).float()
        target_binary = (target > threshold).float()
        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection
        # The epsilon makes two empty masks score 1 instead of dividing by zero.
        return (intersection + 1e-6) / (union + 1e-6)

    def _combined_loss(self, outputs, masks):
        """Return 0.2 * BCE + 0.5 * Dice + 0.3 * focal for one batch."""
        return (0.2 * self.lossF(outputs, masks) +
                0.5 * self.dice_loss(outputs, masks) +
                0.3 * self.focal_loss(outputs, masks))

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss for one training batch."""
        images, masks = batch
        outputs = self(images)
        loss = self._combined_loss(outputs, masks)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the combined loss and the IoU for one validation batch."""
        images, masks = batch
        outputs = self(images)
        loss = self._combined_loss(outputs, masks)
        iou = self.iou_score(outputs, masks)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_iou', iou, prog_bar=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        return optimizer
97
+
98
+ # Helper functions related to the model:
99
+
100
def find_best_model(checkpoint_dir: str = "checkpoints") -> str:
    """Return the path of the checkpoint with the lowest validation loss.

    Only files named ``micrograph-epoch=<n>-val_loss=<x.y>.ckpt`` are
    considered; the loss is read from the file name. Other files are ignored.

    Raises:
        ValueError: If no file in ``checkpoint_dir`` matches the pattern.
    """
    checkpoint_re = re.compile(r"micrograph-epoch=(\d+)-val_loss=(\d+\.\d+)\.ckpt")
    winner = None
    lowest = float('inf')

    for entry in os.listdir(checkpoint_dir):
        parsed = checkpoint_re.match(entry)
        if parsed is None:
            continue
        loss = float(parsed.group(2))
        # Strict comparison: on a tie the first file listed is kept.
        if loss < lowest:
            winner, lowest = entry, loss

    if winner is None:
        raise ValueError("No valid checkpoint files found")

    return str(Path(checkpoint_dir) / winner)
117
+
118
def find_optimal_threshold(model, val_loader, thresholds=np.arange(0.3, 0.7, 0.05)):
    """Return the threshold that maximises IoU over the validation loader.

    Predictions for every batch are collected first, then ``model.iou_score``
    is evaluated once per candidate threshold on the pooled predictions.

    Args:
        model: Module that is callable on an image batch and provides
            ``iou_score(pred, target, threshold=...)``.
        val_loader: Iterable of ``(images, masks)`` batches.
        thresholds: Candidate thresholds, tried in order.

    Returns:
        The best threshold as a float; 0.5 if no candidate scores above 0.
    """
    best_iou = 0.0
    best_threshold = 0.5
    all_preds, all_targets = [], []

    # Run inference in eval mode and on whatever device holds the weights;
    # the loader yields CPU tensors, which would otherwise clash with a model
    # living on an accelerator.
    was_training = model.training
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    try:
        with torch.no_grad():
            for images, masks in tqdm.tqdm(val_loader):
                outputs = model(images.to(device))
                all_preds.append(outputs.cpu())
                all_targets.append(masks.cpu())
    finally:
        # Leave the module in the mode the caller had it in.
        model.train(was_training)

    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    for threshold in thresholds:
        iou = float(model.iou_score(all_preds, all_targets, threshold=threshold))
        # Strict comparison: on a tie the earliest threshold is kept.
        if iou > best_iou:
            best_iou = iou
            best_threshold = threshold
    return float(best_threshold)
138
+
139
def prepare_submission(model, test_df, window_size, threshold=0.5):
    """Predict a binary mask for every test row and return a submission frame.

    Each image is padded up to ``window_size`` by the inference dataset, run
    through the model as a batch of one, cropped back by the reported padding,
    binarised at ``threshold`` and encoded with ``encode_array``.

    Args:
        model: Trained module; its output is indexed as (batch, channel, H, W).
        test_df: Data frame with ``image`` and ``Id`` columns.
        window_size: Minimum image side passed to the inference dataset.
        threshold: Probability above which a pixel is set to 1.

    Returns:
        A DataFrame with columns ``Id`` and ``mask``.
    """
    test_dataset = InferenceMicrographDataset(test_df, window_size=window_size)
    predictions = []
    model.eval()

    # Feed inputs on the device that holds the weights; the dataset yields
    # CPU tensors.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for idx in tqdm.tqdm(range(len(test_dataset))):
            image, image_id, (pad_h, pad_w) = test_dataset[idx]
            output = model(image.unsqueeze(0).to(device))
            # Index the batch and channel axes explicitly; a blanket squeeze()
            # would also drop a height or width of 1.
            pred = output[0, 0].cpu().numpy()
            # Remove the rows/columns that were appended as padding.
            if pad_h > 0:
                pred = pred[:-pad_h, :]
            if pad_w > 0:
                pred = pred[:, :-pad_w]
            pred_mask = (pred > threshold).astype(np.uint8)
            encoded_pred = encode_array(pred_mask)
            predictions.append({'Id': image_id, 'mask': encoded_pred})

    # Name the columns explicitly so an empty test set still has the schema.
    submission_df = pd.DataFrame(predictions, columns=['Id', 'mask'])
    return submission_df
AI Cancer Cell Types Requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ pandas
5
+ pytorch-lightning
6
+ gradio
7
+ matplotlib
8
+ scipy
9
+ scikit-learn
10
+ tqdm
11
+ kaggle