ianpan committed
Commit 041fdf1 · 1 Parent(s): 998c871

use huggingface models
crop.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e43631e45b61a439a3fc9d78b21501a92de8ef67a33ef050d44476f7153e6fae
- size 6228872
net0.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5a60765870a852fd71e1219d895f18ec8f9272a9c785b291b1bed29746d7e42c
- size 112286108
net1.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8f703131880526db7942a55b7867d417f989827fecf3b5b9d11077f6216ee6aa
- size 112286108
net2.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9f9bb1365149d9570272af3aa8c059e4d33cd3a6c3d03b82cf923e868caeaea0
- size 112286108
skp/configs/__init__.py DELETED
@@ -1,21 +0,0 @@
- from types import SimpleNamespace
-
-
- class Config(SimpleNamespace):
-
-     def __getattribute__(self, value):
-         # If an attribute is not specified in the config,
-         # return None instead of raising an error
-         try:
-             return super().__getattribute__(value)
-         except AttributeError:
-             return None
-
-     def __str__(self):
-         # pretty print
-         string = ["config"]
-         string.append("=" * len(string[0]))
-         longest_param_name = max([len(k) for k in [*self.__dict__]])
-         for k, v in self.__dict__.items():
-             string.append(f"{k.ljust(longest_param_name)} : {v}")
-         return "\n".join(string)
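
A quick sketch of how this Config behaves in practice (the cfg_*.py files below build configs exactly this way):

    from skp.configs import Config

    cfg = Config()
    cfg.backbone = "tf_efficientnetv2_s"
    cfg.dropout = 0.1

    # unset attributes return None instead of raising AttributeError,
    # so downstream code can test optional flags without hasattr() guards
    assert cfg.pretrained is None

    # __str__ left-justifies keys to the longest name for aligned printing
    print(cfg)
    # config
    # ======
    # backbone : tf_efficientnetv2_s
    # dropout  : 0.1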
skp/configs/base.py DELETED
@@ -1,21 +0,0 @@
- from types import SimpleNamespace
-
-
- class Config(SimpleNamespace):
-
-     def __getattribute__(self, value):
-         # If an attribute is not specified in the config,
-         # return None instead of raising an error
-         try:
-             return super().__getattribute__(value)
-         except AttributeError:
-             return None
-
-     def __str__(self):
-         # pretty print
-         string = ["config"]
-         string.append("=" * len(string[0]))
-         longest_param_name = max([len(k) for k in [*self.__dict__]])
-         for k, v in self.__dict__.items():
-             string.append(f"{k.ljust(longest_param_name)} : {v}")
-         return "\n".join(string)
skp/configs/boneage/cfg_baseline.py DELETED
@@ -1,117 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_var_embed"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.embed_num_classes = 2
- cfg.embed_dim = 32
- cfg.pretrained = True
- cfg.num_input_channels = 1
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 1
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "simple2d"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years"]
- cfg.vars = "female"
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
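
Albumentations pipelines are called with keyword arguments and return a dict, so a config's transforms can be smoke-tested on a dummy image; a minimal sketch (import path per this repo's layout):

    import numpy as np
    from skp.configs.boneage.cfg_baseline import cfg

    # dummy single-channel image, as if loaded with cv2.IMREAD_GRAYSCALE
    img = np.random.randint(0, 256, (900, 700), dtype=np.uint8)

    out = cfg.train_transforms(image=img)["image"]
    # LongestMaxSize then PadIfNeeded guarantees a 512 x 512 result
    assert out.shape == (cfg.image_height, cfg.image_width)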
skp/configs/boneage/cfg_crop.py DELETED
@@ -1,123 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d"
- cfg.backbone = "mobilenetv3_small_100"
- cfg.pretrained = True
- cfg.num_input_channels = 1
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 4
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
- cfg.model_activation_fn = "sigmoid"
-
- cfg.fold = 0
- cfg.dataset = "crop2d"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/train/"
- cfg.annotations_file = (
-     "/mnt/stor/datasets/bone-age/train_with_bounding_box_crop_coords_kfold.csv"
- )
- cfg.inputs = "imgfile"
- cfg.targets = ["x1", "y1", "w", "h"]
- cfg.normalize_crop_coords = True
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 100
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 16
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- bbox_params = A.BboxParams(format="coco")
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ],
-     bbox_params=bbox_params,
- )
-
- cfg.val_transforms = A.Compose(
-     resize_transforms,
-     bbox_params=bbox_params,
- )
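
Since bbox_params uses the COCO convention, boxes go through the pipeline as [x_min, y_min, width, height] plus a label, and come back rescaled with the image; a sketch with made-up box values:

    import numpy as np
    from skp.configs.boneage.cfg_crop import cfg

    img = np.random.randint(0, 256, (1024, 768), dtype=np.uint8)
    boxes = [[100, 150, 400, 500, "hand"]]  # hypothetical hand crop box

    out = cfg.val_transforms(image=img, bboxes=boxes)
    img_t, boxes_t = out["image"], out["bboxes"]  # box coords track the resize/pad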
skp/configs/boneage/cfg_crop_simple_resize.py DELETED
@@ -1,117 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d"
- cfg.backbone = "mobilenetv3_small_100"
- cfg.pretrained = True
- cfg.num_input_channels = 1
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 4
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
- cfg.model_activation_fn = "sigmoid"
-
- cfg.fold = 0
- cfg.dataset = "crop2d"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/train/"
- cfg.annotations_file = (
-     "/mnt/stor/datasets/bone-age/train_with_bounding_box_crop_coords_kfold.csv"
- )
- cfg.inputs = "imgfile"
- cfg.targets = ["x1", "y1", "w", "h"]
- cfg.normalize_crop_coords = True
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 200
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 16
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- bbox_params = A.BboxParams(format="coco")
- resize_transforms = [
-     A.Resize(height=cfg.image_height, width=cfg.image_width, p=1)
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ],
-     bbox_params=bbox_params,
- )
-
- cfg.val_transforms = A.Compose(
-     resize_transforms,
-     bbox_params=bbox_params,
- )
skp/configs/boneage/cfg_female_channel.py DELETED
@@ -1,114 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 1
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
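
num_input_channels = 2 because the boneage.female_channel dataset (not part of this diff) evidently feeds the patient-sex flag in as a constant second image channel; a hedged sketch of that idea, with the exact encoding assumed:

    import numpy as np

    def add_female_channel(img: np.ndarray, female: int) -> np.ndarray:
        # img: (H, W) grayscale; the second channel is a constant plane,
        # e.g. 255 for female and 0 for male (assumed encoding)
        sex_plane = np.full_like(img, 255 if female else 0)
        return np.stack([img, sex_plane], axis=-1)  # (H, W, 2)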
skp/configs/boneage/cfg_female_channel_MIL.py DELETED
@@ -1,113 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "MIL.net2d_basic_attn"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 1
- cfg.attn_dropout = 0.0
- cfg.attn_version = "v1"
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_grid_patch"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.patch_size = 224
- cfg.patch_num_rows = 5
- cfg.patch_num_cols = 3
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 16
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 560
- cfg.image_width = cfg.image_height  # not used
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
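
The female_channel_grid_patch dataset is configured for a bag of 5 x 3 = 15 patches of 224 x 224 from an image resized to a max side of 560; the actual dataset code is not in this diff, but the tiling presumably looks something like:

    import numpy as np

    def grid_patches(img: np.ndarray, rows: int = 5, cols: int = 3, size: int = 224) -> np.ndarray:
        # evenly spaced top-left corners; assumes both sides >= `size`
        h, w = img.shape[:2]
        ys = np.linspace(0, max(h - size, 0), rows).astype(int)
        xs = np.linspace(0, max(w - size, 0), cols).astype(int)
        patches = [img[y:y + size, x:x + size] for y in ys for x in xs]
        return np.stack(patches)  # (rows * cols, size, size, ...)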
skp/configs/boneage/cfg_female_channel_MIL_lstm.py DELETED
@@ -1,116 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "MIL.net2d_attn"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 1
- cfg.add_lstm = True
- cfg.lstm_dropout = 0.0
- cfg.lstm_num_layers = 1
- cfg.attn_dropout = 0.0
- cfg.attn_version = "v1"
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_grid_patch"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.patch_size = 224
- cfg.patch_num_rows = 5
- cfg.patch_num_cols = 3
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 16
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 560
- cfg.image_width = cfg.image_height  # not used
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
skp/configs/boneage/cfg_female_channel_MIL_transformer.py DELETED
@@ -1,117 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "MIL.net2d_attn"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = 1
- cfg.reduce_feature_dim = 256
- cfg.add_transformer = True
- cfg.transformer_dropout = 0.0
- cfg.transformer_num_layers = 1
- cfg.attn_dropout = 0.0
- cfg.attn_version = "v1"
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_grid_patch"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.patch_size = 224
- cfg.patch_num_rows = 5
- cfg.patch_num_cols = 3
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1Loss"
- cfg.loss_params = {}
-
- cfg.batch_size = 16
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE", "classification.MSE"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 560
- cfg.image_width = cfg.image_height  # not used
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
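
With add_transformer = True, bag features are contextualized by a small nn.TransformerEncoder before attention pooling (see skp/models/MIL/net2d_attn.py below); the shape contract is (B, N, D) in, (B, N, D) out:

    import torch
    import torch.nn as nn

    layer = nn.TransformerEncoderLayer(
        d_model=256, nhead=16, dim_feedforward=256, batch_first=True
    )
    encoder = nn.TransformerEncoder(layer, num_layers=1)

    feats = torch.randn(4, 15, 256)  # 4 bags of 15 patches, 256-dim reduced features
    assert encoder(feats).shape == feats.shape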
skp/configs/boneage/cfg_female_channel_reg_cls.py DELETED
@@ -1,115 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 240]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.DoubleL1Loss"
- cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.DoubleMAE"]
- cfg.val_metric = "mae_reg"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
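
num_classes = [1, 240] pairs a regression head with a 240-bin head over bone age in months. The DoubleL1Loss implementation is not in this diff; one plausible reading of the name, sketched with the configured weights, is L1 on the regression head plus L1 on the softmax-expected age from the binned head:

    import torch
    import torch.nn.functional as F

    def double_l1_loss(reg_out, cls_out, target_months, reg_weight=1.0, cls_weight=0.4):
        # hypothetical interpretation of "DoubleL1Loss"
        reg = F.l1_loss(reg_out.squeeze(1), target_months)
        bins = torch.arange(cls_out.size(1), dtype=cls_out.dtype, device=cls_out.device)
        expected = (cls_out.softmax(dim=1) * bins).sum(dim=1)
        cls = F.l1_loss(expected, target_months)
        return reg_weight * reg + cls_weight * cls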
skp/configs/boneage/cfg_female_channel_reg_cls_clip_outliers_aug.py DELETED
@@ -1,119 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 240]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.clip_outlier_pixels_and_rescale = True
- cfg.clip_as_data_aug = True
- cfg.clip_proba = 0.5
- cfg.clip_bounds = (1, 99)
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.DoubleL1Loss"
- cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.DoubleMAE"]
- cfg.val_metric = "mae_reg"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
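
clip_outlier_pixels_and_rescale with clip_bounds = (1, 99) points to percentile windowing, applied stochastically as augmentation via clip_proba = 0.5; the dataset code is not shown, but the core operation is presumably along these lines:

    import numpy as np

    def clip_and_rescale(img: np.ndarray, bounds=(1, 99)) -> np.ndarray:
        # clip to the given percentiles, then rescale back to 0-255
        lo, hi = np.percentile(img, bounds)
        img = np.clip(img.astype(np.float32), lo, hi)
        img = (img - lo) / max(hi - lo, 1e-6) * 255.0
        return img.astype(np.uint8)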
skp/configs/boneage/cfg_female_channel_reg_cls_match_hist.py DELETED
@@ -1,116 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 240]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_match_hist"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.ref_image_match_hist = "/mnt/stor/datasets/bone-age/reference_cropped_image_for_histogram_matching.png"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.DoubleL1Loss"
- cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.DoubleMAE"]
- cfg.val_metric = "mae_reg"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
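
Histogram matching against the fixed reference image can be done with scikit-image; a sketch of what the female_channel_match_hist dataset presumably does per image (the dataset itself is not in this diff):

    import cv2
    import numpy as np
    from skimage.exposure import match_histograms
    from skp.configs.boneage.cfg_female_channel_reg_cls_match_hist import cfg

    ref = cv2.imread(cfg.ref_image_match_hist, cv2.IMREAD_GRAYSCALE)

    def match_to_reference(img: np.ndarray) -> np.ndarray:
        # remap the image's intensity histogram onto the reference's
        return match_histograms(img, ref).astype(np.uint8)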
skp/configs/boneage/cfg_female_channel_with_cls.py DELETED
@@ -1,115 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 24]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_with_cls"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years", "bone_age_categorical"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1CELoss"
- cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE_Accuracy"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers.py DELETED
@@ -1,117 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 24]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_with_cls"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years", "bone_age_categorical"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.clip_outlier_pixels_and_rescale = True
- cfg.clip_bounds = (1, 99)
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1CELoss"
- cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE_Accuracy"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers_aug.py DELETED
@@ -1,119 +0,0 @@
- import albumentations as A
- import cv2
-
- from skp.configs import Config
-
-
- cfg = Config()
- cfg.neptune_mode = "async"
-
- cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/"
- cfg.project = "gradientecho/SKP"
-
- cfg.task = "classification"
-
- cfg.model = "classification.net2d_multihead"
- cfg.backbone = "tf_efficientnetv2_s"
- cfg.pretrained = True
- cfg.num_input_channels = 2
- cfg.pool = "gem"
- cfg.pool_params = {"p": 3}
- cfg.dropout = 0.1
- cfg.num_classes = [1, 24]
- cfg.num_heads = 2
- cfg.normalization = "-1_1"
- cfg.normalization_params = {"min": 0, "max": 255}
- cfg.backbone_img_size = False
-
- cfg.fold = 0
- cfg.dataset = "boneage.female_channel_with_cls"
- cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/"
- cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv"
- cfg.inputs = "imgfile0"
- cfg.targets = ["bone_age_years", "bone_age_categorical"]
- cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE
- cfg.num_workers = 16
- cfg.clip_outlier_pixels_and_rescale = True
- cfg.clip_as_data_aug = True
- cfg.clip_proba = 0.5
- cfg.clip_bounds = (1, 99)
- cfg.pin_memory = True
- cfg.persistent_workers = True
- cfg.sampler = "IterationBasedSampler"
- cfg.num_iterations_per_epoch = 1000
-
- cfg.loss = "classification.L1CELoss"
- cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2}
-
- cfg.batch_size = 32
- cfg.num_epochs = 10
- cfg.optimizer = "AdamW"
- cfg.optimizer_params = {"lr": 3e-4}
-
- cfg.scheduler = "LinearWarmupCosineAnnealingLR"
- cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000}
- cfg.scheduler_interval = "step"
-
- cfg.val_batch_size = cfg.batch_size * 2
- cfg.metrics = ["classification.MAE_Accuracy"]
- cfg.val_metric = "mae_mean"
- cfg.val_track = "min"
-
- cfg.image_height = 512
- cfg.image_width = 512
-
- resize_transforms = [
-     A.LongestMaxSize(max_size=cfg.image_height, p=1),
-     A.PadIfNeeded(
-         min_height=cfg.image_height,
-         min_width=cfg.image_width,
-         border_mode=cv2.BORDER_CONSTANT,
-         p=1,
-     ),
- ]
-
- cfg.train_transforms = A.Compose(
-     resize_transforms
-     + [
-         A.VerticalFlip(p=0.5),
-         A.HorizontalFlip(p=0.5),
-         A.SomeOf(
-             [
-                 A.ShiftScaleRotate(
-                     shift_limit=0.2,
-                     scale_limit=0.0,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.2,
-                     rotate_limit=0,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.ShiftScaleRotate(
-                     shift_limit=0.0,
-                     scale_limit=0.0,
-                     rotate_limit=30,
-                     border_mode=cv2.BORDER_CONSTANT,
-                     p=1,
-                 ),
-                 A.GaussianBlur(p=1),
-                 A.GaussNoise(p=1),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.3, brightness_limit=0.0, p=1
-                 ),
-                 A.RandomBrightnessContrast(
-                     contrast_limit=0.0, brightness_limit=0.3, p=1
-                 ),
-             ],
-             n=3,
-             p=0.9,
-             replace=False,
-         ),
-     ]
- )
-
- cfg.val_transforms = A.Compose(resize_transforms)
skp/models/MIL/net2d_attn.py DELETED
@@ -1,286 +0,0 @@
- """
- 2D model for multiple instance learning (MIL)
- Performs attention over bag of features (i.e., attention-weighted mean of features)
- Option to add LSTM or Transformer before attention aggregation
- Uses timm backbones
- """
-
- import re
- import torch
- import torch.nn as nn
-
- from einops import rearrange
- from timm import create_model
- from typing import Dict, Optional, Tuple
-
- from skp.configs.base import Config
- from skp.models.modules import FeatureReduction
- from skp.models.pooling import get_pool_layer
-
-
- class Attention(nn.Module):
-     """
-     Given a batch containing bags of features (B, N, D),
-     generate attention scores over the features in a bag, N,
-     and perform an attention-weighted mean of the features (B, D)
-     """
-
-     def __init__(self, embed_dim: int, dropout: float = 0.0, version: str = "v1"):
-         super().__init__()
-         version = version.lower()
-         if version == "v1":
-             self.mlp = nn.Sequential(
-                 nn.Tanh(), nn.Dropout(dropout), nn.Linear(embed_dim, 1)
-             )
-         elif version == "v2":
-             self.mlp = nn.Sequential(
-                 nn.Linear(embed_dim, embed_dim),
-                 nn.Tanh(),
-                 nn.Dropout(dropout),
-                 nn.Linear(embed_dim, 1),
-             )
-
-     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
-         a = self.mlp(x)
-         a = a.softmax(dim=1)
-         x = (x * a).sum(dim=1)
-         return x, a
-
-
- class BiLSTM(nn.Module):
-     def __init__(self, embed_dim: int, dropout: float = 0.0, num_layers: int = 1):
-         super().__init__()
-         self.lstm = nn.LSTM(
-             input_size=embed_dim,
-             hidden_size=embed_dim // 2,
-             num_layers=num_layers,
-             bias=True,
-             batch_first=True,
-             dropout=dropout,
-             bidirectional=True,
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x, _ = self.lstm(x)
-         return x
-
-
- class Transformer(nn.Module):
-     def __init__(
-         self,
-         embed_dim: int,
-         dropout: float = 0.0,
-         num_layers: int = 1,
-         nhead: int = 16,
-         activation: str = "gelu",
-     ):
-         super().__init__()
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=embed_dim,
-             nhead=nhead,
-             dim_feedforward=embed_dim,
-             dropout=dropout,
-             activation=activation,
-             batch_first=True,
-             norm_first=False,
-             bias=True,
-         )
-         self.T = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-
-     def forward(
-         self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
-     ) -> torch.Tensor:
-         return self.T(x, mask=mask)
-
-
- class Net(nn.Module):
-     def __init__(self, cfg: Config):
-         super().__init__()
-         self.cfg = cfg
-         backbone_args = {
-             "pretrained": self.cfg.pretrained,
-             "num_classes": 0,
-             "global_pool": "",
-             "features_only": self.cfg.features_only,
-             "in_chans": self.cfg.num_input_channels,
-         }
-         if self.cfg.backbone_img_size:
-             # some models require specifying image size (e.g., coatnet)
-             if "efficientvit" in self.cfg.backbone:
-                 backbone_args["img_size"] = self.cfg.image_height
-             else:
-                 backbone_args["img_size"] = (
-                     self.cfg.image_height,
-                     self.cfg.image_width,
-                 )
-         self.backbone = create_model(self.cfg.backbone, **backbone_args)
-         # get feature dim by passing sample through net
-         self.feature_dim = self.backbone(
-             torch.randn(
-                 (
-                     2,
-                     self.cfg.num_input_channels,
-                     self.cfg.image_height,
-                     self.cfg.image_width,
-                 )
-             )
-         ).size(
-             -1 if "xcit" in self.cfg.backbone else 1
-         )  # xcit models are channels-last
-
-         self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
-         self.pooling = get_pool_layer(self.cfg, dim=2)
-
-         if isinstance(self.cfg.reduce_feature_dim, int):
-             self.backbone = nn.Sequential(
-                 self.backbone,
-                 FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
-             )
-             self.feature_dim = self.cfg.reduce_feature_dim
-
-         if self.cfg.add_lstm:
-             self.pre_attn = BiLSTM(
-                 embed_dim=self.feature_dim,
-                 dropout=self.cfg.lstm_dropout or 0.0,
-                 num_layers=self.cfg.lstm_num_layers or 1,
-             )
-         elif self.cfg.add_transformer:
-             self.pre_attn = Transformer(
-                 embed_dim=self.feature_dim,
-                 dropout=self.cfg.transformer_dropout or 0.0,
-                 num_layers=self.cfg.transformer_num_layers or 1,
-                 nhead=self.cfg.transformer_nhead or 16,
-                 activation=self.cfg.transformer_act or "gelu",
-             )
-         else:
-             self.pre_attn = nn.Identity()
-
-         self.attn = Attention(
-             self.feature_dim,
-             dropout=self.cfg.attn_dropout,
-             version=self.cfg.attn_version or "v1",
-         )
-         self.dropout = nn.Dropout(p=self.cfg.dropout)
-         self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes)
-
-         if self.cfg.load_pretrained_backbone:
-             print(
-                 f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
-             )
-             weights = torch.load(
-                 self.cfg.load_pretrained_backbone,
-                 map_location=lambda storage, loc: storage,
-             )["state_dict"]
-             # Replace model prefix as this does not exist in Net
-             weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
-             # Get backbone only
-             weights = {
-                 re.sub(r"^backbone.", "", k): v
-                 for k, v in weights.items()
-                 if "backbone" in k
-             }
-             self.backbone.load_state_dict(weights)
-
-         self.criterion = None
-
-         self.backbone_frozen = False
-         if self.cfg.freeze_backbone:
-             self.freeze_backbone()
-
-     def normalize(self, x: torch.Tensor) -> torch.Tensor:
-         if self.cfg.normalization == "-1_1":
-             mini, maxi = (
-                 self.cfg.normalization_params["min"],
-                 self.cfg.normalization_params["max"],
-             )
-             x = x - mini
-             x = x / (maxi - mini)
-             x = x - 0.5
-             x = x * 2.0
-         elif self.cfg.normalization == "0_1":
-             mini, maxi = (
-                 self.cfg.normalization_params["min"],
-                 self.cfg.normalization_params["max"],
-             )
-             x = x - mini
-             x = x / (maxi - mini)
-         elif self.cfg.normalization == "mean_sd":
-             mean, sd = (
-                 self.cfg.normalization_params["mean"],
-                 self.cfg.normalization_params["sd"],
-             )
-             x = (x - mean) / sd
-         elif self.cfg.normalization == "per_channel_mean_sd":
-             mean, sd = (
-                 self.cfg.normalization_params["mean"],
-                 self.cfg.normalization_params["sd"],
-             )
-             assert len(mean) == len(sd) == x.size(1)
-             mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0)
-             for i in range(x.ndim - 2):
-                 mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
-             x = (x - mean) / sd
-         elif self.cfg.normalization == "none":
-             x = x
-         return x
-
-     def forward(
-         self,
-         batch: Dict,
-         return_loss: bool = False,
-         return_features: bool = False,
-         return_attn_scores: bool = False,
-     ) -> Dict[str, torch.Tensor]:
-         x = batch["x"]
-         y = batch.get("y", None)
-
-         if return_loss:
-             assert y is not None
-
-         b, n = x.shape[:2]
-         x = rearrange(x, "b n c h w -> (b n) c h w")
-         features = self.extract_features(x, normalize=True)
-         features = rearrange(features, "(b n) d -> b n d", b=b, n=n)
-         if isinstance(self.pre_attn, Transformer):
-             features = self.pre_attn(features, mask=batch.get("mask", None))
-         else:
-             features = self.pre_attn(features)
-         features, attn_scores = self.attn(features)
-
-         if self.cfg.multisample_dropout:
-             logits = torch.stack(
-                 [self.linear(self.dropout(features)) for _ in range(5)]
-             ).mean(0)
-         else:
-             logits = self.linear(self.dropout(features))
-
-         if self.cfg.model_activation_fn == "sigmoid":
-             logits = logits.sigmoid()
-         elif self.cfg.model_activation_fn == "softmax":
-             logits = logits.softmax(dim=1)
-
-         out = {"logits": logits}
-         if return_features:
-             out["features"] = features
-         if return_attn_scores:
-             out["attn_scores"] = attn_scores
-         if return_loss:
-             loss = self.criterion(out, batch)
-             if isinstance(loss, dict):
-                 out.update(loss)
-             else:
-                 out["loss"] = loss
-
-         return out
-
-     def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
-         x = self.normalize(x) if normalize else x
-         return self.pooling(self.backbone(x))
-
-     def freeze_backbone(self) -> None:
-         for param in self.backbone.parameters():
-             param.requires_grad = False
-         self.backbone_frozen = True
-
-     def set_criterion(self, loss: nn.Module) -> None:
-         self.criterion = loss
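
The aggregation in Attention above is just a softmax over the N instances followed by a weighted sum; a standalone shape check of the "v1" variant:

    import torch
    import torch.nn as nn

    B, N, D = 4, 15, 1280  # bags, instances per bag, feature dim
    x = torch.randn(B, N, D)

    # "v1" scoring head: Tanh -> Dropout -> Linear(D, 1)
    mlp = nn.Sequential(nn.Tanh(), nn.Dropout(0.0), nn.Linear(D, 1))
    a = mlp(x).softmax(dim=1)    # (B, N, 1); weights sum to 1 over N
    pooled = (x * a).sum(dim=1)  # (B, D) attention-weighted mean
    assert pooled.shape == (B, D)
    assert torch.allclose(a.sum(dim=1), torch.ones(B, 1))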
skp/models/MIL/net2d_basic_attn.py DELETED
@@ -1,284 +0,0 @@
1
- """
2
- 2D model for multiple instance learning (MIL)
3
- Performs attention over bag of features (i.e., attention-weighted mean of features)
4
- Uses timm backbones
5
- """
6
-
7
- import re
8
- import torch
9
- import torch.nn as nn
10
-
11
- from einops import rearrange
12
- from timm import create_model
13
- from typing import Dict, Optional, Tuple
14
-
15
- from skp.configs.base import Config
16
- from skp.models.modules import FeatureReduction
17
- from skp.models.pooling import get_pool_layer
18
-
19
-
20
- class Attention(nn.Module):
21
- """
22
- Given a batch containing bags of features (B, N, D),
23
- generate attention scores over the features in a bag, N,
24
- and perform an attention-weighted mean of the features (B, D)
25
- """
26
-
27
- def __init__(self, embed_dim: int, dropout: float = 0.0, version: str = "v1"):
28
- super().__init__()
29
- version = version.lower()
30
- if version == "v1":
31
- self.mlp = nn.Sequential(
32
- nn.Tanh(), nn.Dropout(dropout), nn.Linear(embed_dim, 1)
33
- )
34
- elif version == "v2":
35
- self.mlp = nn.Sequential(
36
- nn.Linear(embed_dim, embed_dim),
37
- nn.Tanh(),
38
- nn.Dropout(dropout),
39
- nn.Linear(embed_dim, 1),
40
- )
41
-
42
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
43
- a = self.mlp(x)
44
- a = a.softmax(dim=1)
45
- x = (x * a).sum(dim=1)
46
- return x, a
47
-
48
-
49
- class BiLSTM(nn.Module):
50
- def __init__(self, embed_dim: int, dropout: float = 0.0, num_layers: int = 1):
51
- super().__init__()
52
- self.lstm = nn.LSTM(
53
- input_size=embed_dim,
54
- hidden_size=embed_dim // 2,
55
- num_layers=num_layers,
56
- bias=True,
57
- batch_first=True,
58
- dropout=dropout,
59
- bidirectional=True,
60
- )
61
-
62
- def forward(self, x: torch.Tensor) -> torch.Tensor:
63
- x, _ = self.lstm(x)
64
- return x
65
-
66
-
67
- class Transformer(nn.Module):
68
- def __init__(
69
- self,
70
- embed_dim: int,
71
- dropout: float = 0.0,
72
- num_layers: int = 1,
73
- nheads: int = 16,
74
- activation: str = "gelu",
75
- ):
76
- super().__init__()
77
- encoder_layer = nn.TransformerEncoderLayer(
78
- d_model=embed_dim,
79
- dim_feedforward=embed_dim,
80
- dropout=dropout,
81
- activation=activation,
82
- batch_first=True,
83
- norm_first=False,
84
- bias=True,
85
- )
86
- self.T = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
87
-
88
- def forward(
89
- self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
90
- ) -> torch.Tensor:
91
- return self.T(x, mask=mask)
92
-
93
-
94
- class Net(nn.Module):
95
- def __init__(self, cfg: Config):
96
- super().__init__()
97
- self.cfg = cfg
98
- backbone_args = {
99
- "pretrained": self.cfg.pretrained,
100
- "num_classes": 0,
101
- "global_pool": "",
102
- "features_only": self.cfg.features_only,
103
- "in_chans": self.cfg.num_input_channels,
104
- }
105
- if self.cfg.backbone_img_size:
106
- # some models require specifying image size (e.g., coatnet)
107
- if "efficientvit" in self.cfg.backbone:
108
- backbone_args["img_size"] = self.cfg.image_height
109
- else:
110
- backbone_args["img_size"] = (
111
- self.cfg.image_height,
112
- self.cfg.image_width,
113
- )
114
- self.backbone = create_model(self.cfg.backbone, **backbone_args)
115
- # get feature dim by passing sample through net
116
- self.feature_dim = self.backbone(
117
- torch.randn(
118
- (
119
- 2,
120
- self.cfg.num_input_channels,
121
- self.cfg.image_height,
122
- self.cfg.image_width,
123
- )
124
- )
125
- ).size(
126
- -1 if "xcit" in self.cfg.backbone else 1
127
- ) # xcit models are channels-last
128
-
129
- self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
130
- self.pooling = get_pool_layer(self.cfg, dim=2)
131
-
132
- if isinstance(self.cfg.reduce_feature_dim, int):
133
- self.backbone = nn.Sequential(
134
- self.backbone,
135
- FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
136
- )
137
- self.feature_dim = self.cfg.reduce_feature_dim
138
-
139
- if self.cfg.add_lstm:
140
- self.pre_attn = BiLSTM(
141
- embed_dim=self.feature_dim,
142
- dropout=self.cfg.lstm_dropout or 0.0,
143
- num_layers=self.cfg.lstm_num_layers or 1,
144
- )
145
- elif self.cfg.add_transformer:
146
- self.pre_attn = Transformer(
147
- embed_dim=self.feature_dim,
148
- dropout=self.cfg.transformer_dropout or 0.0,
149
- num_layers=self.cfg.transformer_num_layers or 1,
150
- nheads=self.cfg.transformer_nheads or 16,
151
- activation=self.cfg.transformer_act or "gelu",
152
- )
153
- else:
154
- self.pre_attn = nn.Identity()
155
-
156
- self.attn = Attention(
157
- self.feature_dim,
158
- dropout=self.cfg.attn_dropout,
159
- version=self.cfg.attn_version or "v1",
160
- )
161
- self.dropout = nn.Dropout(p=self.cfg.dropout)
162
- self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes)
163
-
164
- if self.cfg.load_pretrained_backbone:
165
- print(
166
- f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
167
- )
168
- weights = torch.load(
169
- self.cfg.load_pretrained_backbone,
170
- map_location=lambda storage, loc: storage,
171
- )["state_dict"]
172
- # Replace model prefix as this does not exist in Net
173
- weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
174
- # Get backbone only
175
- weights = {
176
- re.sub(r"^backbone.", "", k): v
177
- for k, v in weights.items()
178
- if "backbone" in k
179
- }
180
- self.backbone.load_state_dict(weights)
181
-
182
- self.criterion = None
183
-
184
- self.backbone_frozen = False
185
- if self.cfg.freeze_backbone:
186
- self.freeze_backbone()
187
-
188
- def normalize(self, x: torch.Tensor) -> torch.Tensor:
189
- if self.cfg.normalization == "-1_1":
190
- mini, maxi = (
191
- self.cfg.normalization_params["min"],
192
- self.cfg.normalization_params["max"],
193
- )
194
- x = x - mini
195
- x = x / (maxi - mini)
196
- x = x - 0.5
197
- x = x * 2.0
198
- elif self.cfg.normalization == "0_1":
199
- mini, maxi = (
200
- self.cfg.normalization_params["min"],
201
- self.cfg.normalization_params["max"],
202
- )
203
- x = x - mini
204
- x = x / (maxi - mini)
205
- elif self.cfg.normalization == "mean_sd":
206
- mean, sd = (
207
- self.cfg.normalization_params["mean"],
208
- self.cfg.normalization_params["sd"],
209
- )
210
- x = (x - mean) / sd
211
- elif self.cfg.normalization == "per_channel_mean_sd":
212
- mean, sd = (
213
- self.cfg.normalization_params["mean"],
214
- self.cfg.normalization_params["sd"],
215
- )
216
- assert len(mean) == len(sd) == x.size(1)
217
- mean, sd = torch.tensor(mean, device=x.device).unsqueeze(0), torch.tensor(sd, device=x.device).unsqueeze(0)  # keep stats on the input's device
218
- for i in range(x.ndim - 2):
219
- mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
220
- x = (x - mean) / sd
221
- elif self.cfg.normalization == "none":
222
- x = x
223
- return x
224
-
225
- def forward(
226
- self,
227
- batch: Dict,
228
- return_loss: bool = False,
229
- return_features: bool = False,
230
- return_attn_scores: bool = False,
231
- ) -> Dict[str, torch.Tensor]:
232
- x = batch["x"]
233
- y = batch.get("y", None)
234
-
235
- if return_loss:
236
- assert y is not None
237
-
238
- b, n = x.shape[:2]
239
- x = rearrange(x, "b n c h w -> (b n) c h w")
240
- features = self.extract_features(x, normalize=True)
241
- features = rearrange(features, "(b n) d -> b n d", b=b, n=n)
242
- if isinstance(self.pre_attn, Transformer):
243
- features = self.pre_attn(features, mask=batch.get("mask", None))
244
- else:
245
- features = self.pre_attn(features)
246
- features, attn_scores = self.attn(features)
247
-
248
- if self.cfg.multisample_dropout:
249
- logits = torch.stack(
250
- [self.linear(self.dropout(features)) for _ in range(5)]
251
- ).mean(0)
252
- else:
253
- logits = self.linear(self.dropout(features))
254
-
255
- if self.cfg.model_activation_fn == "sigmoid":
256
- logits = logits.sigmoid()
257
- elif self.cfg.model_activation_fn == "softmax":
258
- logits = logits.softmax(dim=1)
259
-
260
- out = {"logits": logits}
261
- if return_features:
262
- out["features"] = features
263
- if return_attn_scores:
264
- out["attn_scores"] = attn_scores
265
- if return_loss:
266
- loss = self.criterion(out, batch)
267
- if isinstance(loss, dict):
268
- out.update(loss)
269
- else:
270
- out["loss"] = loss
271
-
272
- return out
273
-
274
- def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
275
- x = self.normalize(x) if normalize else x
276
- return self.pooling(self.backbone(x))
277
-
278
- def freeze_backbone(self) -> None:
279
- for param in self.backbone.parameters():
280
- param.requires_grad = False
281
- self.backbone_frozen = True
282
-
283
- def set_criterion(self, loss: nn.Module) -> None:
284
- self.criterion = loss
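
Both optional pre-attention encoders preserve the feature dimension so the attention head downstream is unchanged: the BiLSTM uses hidden_size = embed_dim // 2 so the concatenated forward/backward states return to embed_dim, and the transformer encoder maps embed_dim to embed_dim. A toy check of those shapes (sizes are illustrative):

import torch
import torch.nn as nn

embed_dim, b, n = 64, 2, 5
feats = torch.randn(b, n, embed_dim)

lstm = nn.LSTM(embed_dim, embed_dim // 2, batch_first=True, bidirectional=True)
out, _ = lstm(feats)
print(out.shape)  # torch.Size([2, 5, 64])

layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=16, dim_feedforward=embed_dim, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=1)
print(encoder(feats).shape)  # torch.Size([2, 5, 64])
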
skp/models/classification/net2d.py DELETED
@@ -1,172 +0,0 @@
1
- """
2
- Simple model for 2D classification (or regression)
3
- Uses timm for backbones
4
- """
5
-
6
- import re
7
- import torch
8
- import torch.nn as nn
9
-
10
- from timm import create_model
11
- from typing import Dict
12
-
13
- from skp.configs.base import Config
14
- from skp.models.modules import FeatureReduction
15
- from skp.models.pooling import get_pool_layer
16
-
17
-
18
- class Net(nn.Module):
19
- def __init__(self, cfg: Config):
20
- super().__init__()
21
- self.cfg = cfg
22
- backbone_args = {
23
- "pretrained": self.cfg.pretrained,
24
- "num_classes": 0,
25
- "global_pool": "",
26
- "features_only": self.cfg.features_only,
27
- "in_chans": self.cfg.num_input_channels,
28
- }
29
- if self.cfg.backbone_img_size:
30
- # some models require specifying image size (e.g., coatnet)
31
- if "efficientvit" in self.cfg.backbone:
32
- backbone_args["img_size"] = self.cfg.image_height
33
- else:
34
- backbone_args["img_size"] = (
35
- self.cfg.image_height,
36
- self.cfg.image_width,
37
- )
38
- self.backbone = create_model(self.cfg.backbone, **backbone_args)
39
- # get feature dim by passing sample through net
40
- self.feature_dim = self.backbone(
41
- torch.randn(
42
- (
43
- 2,
44
- self.cfg.num_input_channels,
45
- self.cfg.image_height,
46
- self.cfg.image_width,
47
- )
48
- )
49
- ).size(
50
- -1 if "xcit" in self.cfg.backbone else 1
51
- ) # xcit models are channels-last
52
-
53
- self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
54
- self.pooling = get_pool_layer(self.cfg, dim=2)
55
-
56
- if isinstance(self.cfg.reduce_feature_dim, int):
57
- self.backbone = nn.Sequential(
58
- self.backbone,
59
- FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
60
- )
61
- self.feature_dim = self.cfg.reduce_feature_dim
62
-
63
- self.dropout = nn.Dropout(p=self.cfg.dropout)
64
- self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes)
65
-
66
- if self.cfg.load_pretrained_backbone:
67
- print(
68
- f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
69
- )
70
- weights = torch.load(
71
- self.cfg.load_pretrained_backbone,
72
- map_location=lambda storage, loc: storage,
73
- )["state_dict"]
74
- # Replace model prefix as this does not exist in Net
75
- weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
76
- # Get backbone only
77
- weights = {
78
- re.sub(r"^backbone.", "", k): v
79
- for k, v in weights.items()
80
- if "backbone" in k
81
- }
82
- self.backbone.load_state_dict(weights)
83
-
84
- self.criterion = None
85
-
86
- self.backbone_frozen = False
87
- if self.cfg.freeze_backbone:
88
- self.freeze_backbone()
89
-
90
- def normalize(self, x: torch.Tensor) -> torch.Tensor:
91
- if self.cfg.normalization == "-1_1":
92
- mini, maxi = (
93
- self.cfg.normalization_params["min"],
94
- self.cfg.normalization_params["max"],
95
- )
96
- x = x - mini
97
- x = x / (maxi - mini)
98
- x = x - 0.5
99
- x = x * 2.0
100
- elif self.cfg.normalization == "0_1":
101
- mini, maxi = (
102
- self.cfg.normalization_params["min"],
103
- self.cfg.normalization_params["max"],
104
- )
105
- x = x - mini
106
- x = x / (maxi - mini)
107
- elif self.cfg.normalization == "mean_sd":
108
- mean, sd = (
109
- self.cfg.normalization_params["mean"],
110
- self.cfg.normalization_params["sd"],
111
- )
112
- x = (x - mean) / sd
113
- elif self.cfg.normalization == "per_channel_mean_sd":
114
- mean, sd = (
115
- self.cfg.normalization_params["mean"],
116
- self.cfg.normalization_params["sd"],
117
- )
118
- assert len(mean) == len(sd) == x.size(1)
119
- mean, sd = torch.tensor(mean, device=x.device).unsqueeze(0), torch.tensor(sd, device=x.device).unsqueeze(0)  # keep stats on the input's device
120
- for i in range(x.ndim - 2):
121
- mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
122
- x = (x - mean) / sd
123
- elif self.cfg.normalization == "none":
124
- x = x
125
- return x
126
-
127
- def forward(
128
- self, batch: Dict, return_loss: bool = False, return_features: bool = False
129
- ) -> Dict[str, torch.Tensor]:
130
- x = batch["x"]
131
- y = batch.get("y", None)
132
-
133
- if return_loss:
134
- assert y is not None
135
-
136
- features = self.extract_features(x, normalize=True)
137
-
138
- if self.cfg.multisample_dropout:
139
- logits = torch.stack(
140
- [self.linear(self.dropout(features)) for _ in range(5)]
141
- ).mean(0)
142
- else:
143
- logits = self.linear(self.dropout(features))
144
-
145
- if self.cfg.model_activation_fn == "sigmoid":
146
- logits = logits.sigmoid()
147
- elif self.cfg.model_activation_fn == "softmax":
148
- logits = logits.softmax(dim=1)
149
-
150
- out = {"logits": logits}
151
- if return_features:
152
- out["features"] = features
153
- if return_loss:
154
- loss = self.criterion(out, batch)
155
- if isinstance(loss, dict):
156
- out.update(loss)
157
- else:
158
- out["loss"] = loss
159
-
160
- return out
161
-
162
- def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
163
- x = self.normalize(x) if normalize else x
164
- return self.pooling(self.backbone(x))
165
-
166
- def freeze_backbone(self) -> None:
167
- for param in self.backbone.parameters():
168
- param.requires_grad = False
169
- self.backbone_frozen = True
170
-
171
- def set_criterion(self, loss: nn.Module) -> None:
172
- self.criterion = loss
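
For context, this is roughly how Net was driven by a config before this commit (values are illustrative; the conditionals above rely on unset config attributes resolving to a falsy value):

import torch
from skp.configs.base import Config            # as it existed before this commit
from skp.models.classification.net2d import Net

cfg = Config()
cfg.backbone = "resnet18"                      # any timm model name
cfg.pretrained = False
cfg.num_input_channels = 1
cfg.image_height = cfg.image_width = 224       # used to probe feature_dim
cfg.pool = "avg"
cfg.dropout = 0.1
cfg.num_classes = 1
cfg.normalization = "0_1"
cfg.normalization_params = {"min": 0, "max": 255}

net = Net(cfg).eval()
with torch.inference_mode():
    out = net({"x": torch.rand(2, 1, 224, 224) * 255})
print(out["logits"].shape)                     # torch.Size([2, 1])
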
skp/models/classification/net2d_multihead.py DELETED
@@ -1,176 +0,0 @@
1
- """
2
- Simple model for 2D classification (or regression) with multiple heads
3
- Uses timm for backbones
4
- """
5
-
6
- import re
7
- import torch
8
- import torch.nn as nn
9
-
10
- from collections.abc import Sequence
11
- from timm import create_model
12
- from typing import Dict
13
-
14
- from skp.configs.base import Config
15
- from skp.models.modules import FeatureReduction
16
- from skp.models.pooling import get_pool_layer
17
-
18
-
19
- class Net(nn.Module):
20
- def __init__(self, cfg: Config):
21
- super().__init__()
22
- self.cfg = cfg
23
- assert (
24
- isinstance(self.cfg.num_classes, Sequence)
25
- and len(self.cfg.num_classes) == self.cfg.num_heads
26
- ), f"cfg.num_classes should be sequence of length {self.cfg.num_heads} corresponding to each head"
27
- backbone_args = {
28
- "pretrained": self.cfg.pretrained,
29
- "num_classes": 0,
30
- "global_pool": "",
31
- "features_only": self.cfg.features_only,
32
- "in_chans": self.cfg.num_input_channels,
33
- }
34
- if self.cfg.backbone_img_size:
35
- # some models require specifying image size (e.g., coatnet)
36
- if "efficientvit" in self.cfg.backbone:
37
- backbone_args["img_size"] = self.cfg.image_height
38
- else:
39
- backbone_args["img_size"] = (
40
- self.cfg.image_height,
41
- self.cfg.image_width,
42
- )
43
- self.backbone = create_model(self.cfg.backbone, **backbone_args)
44
- # get feature dim by passing sample through net
45
- self.feature_dim = self.backbone(
46
- torch.randn(
47
- (
48
- 2,
49
- self.cfg.num_input_channels,
50
- self.cfg.image_height,
51
- self.cfg.image_width,
52
- )
53
- )
54
- ).size(
55
- -1 if "xcit" in self.cfg.backbone else 1
56
- ) # xcit models are channels-last
57
-
58
- self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
59
- self.pooling = get_pool_layer(self.cfg, dim=2)
60
-
61
- if isinstance(self.cfg.reduce_feature_dim, int):
62
- self.backbone = nn.Sequential(
63
- self.backbone,
64
- FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
65
- )
66
- self.feature_dim = self.cfg.reduce_feature_dim
67
-
68
- self.dropout = nn.Dropout(p=self.cfg.dropout)
69
- self.linear = nn.ModuleList()
70
- for i in range(self.cfg.num_heads):
71
- self.linear.append(nn.Linear(self.feature_dim, self.cfg.num_classes[i]))
72
-
73
- if self.cfg.load_pretrained_backbone:
74
- print(
75
- f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
76
- )
77
- weights = torch.load(
78
- self.cfg.load_pretrained_backbone,
79
- map_location=lambda storage, loc: storage,
80
- )["state_dict"]
81
- # Replace model prefix as this does not exist in Net
82
- weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
83
- # Get backbone only
84
- weights = {
85
- re.sub(r"^backbone.", "", k): v
86
- for k, v in weights.items()
87
- if "backbone" in k
88
- }
89
- self.backbone.load_state_dict(weights)
90
-
91
- self.criterion = None
92
-
93
- self.backbone_frozen = False
94
- if self.cfg.freeze_backbone:
95
- self.freeze_backbone()
96
-
97
- def normalize(self, x: torch.Tensor) -> torch.Tensor:
98
- if self.cfg.normalization == "-1_1":
99
- mini, maxi = (
100
- self.cfg.normalization_params["min"],
101
- self.cfg.normalization_params["max"],
102
- )
103
- x = x - mini
104
- x = x / (maxi - mini)
105
- x = x - 0.5
106
- x = x * 2.0
107
- elif self.cfg.normalization == "0_1":
108
- mini, maxi = (
109
- self.cfg.normalization_params["min"],
110
- self.cfg.normalization_params["max"],
111
- )
112
- x = x - mini
113
- x = x / (maxi - mini)
114
- elif self.cfg.normalization == "mean_sd":
115
- mean, sd = (
116
- self.cfg.normalization_params["mean"],
117
- self.cfg.normalization_params["sd"],
118
- )
119
- x = (x - mean) / sd
120
- elif self.cfg.normalization == "per_channel_mean_sd":
121
- mean, sd = (
122
- self.cfg.normalization_params["mean"],
123
- self.cfg.normalization_params["sd"],
124
- )
125
- assert len(mean) == len(sd) == x.size(1)
126
- mean, sd = torch.tensor(mean, device=x.device).unsqueeze(0), torch.tensor(sd, device=x.device).unsqueeze(0)  # keep stats on the input's device
127
- for i in range(x.ndim - 2):
128
- mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
129
- x = (x - mean) / sd
130
- elif self.cfg.normalization == "none":
131
- x = x
132
- return x
133
-
134
- def forward(
135
- self, batch: Dict, return_loss: bool = False, return_features: bool = False
136
- ) -> Dict[str, torch.Tensor]:
137
- x = batch["x"]
138
- y = batch.get("y", None)
139
-
140
- if return_loss:
141
- assert y is not None
142
-
143
- features = self.extract_features(x, normalize=True)
144
-
145
- out = {}
146
- for head_idx, each_head in enumerate(self.linear):
147
- if self.cfg.multisample_dropout:
148
- logits = torch.stack(
149
- [each_head(self.dropout(features)) for _ in range(5)]
150
- ).mean(0)
151
- else:
152
- logits = each_head(self.dropout(features))
153
- out[f"logits{head_idx}"] = logits
154
-
155
- if return_features:
156
- out["features"] = features
157
- if return_loss:
158
- loss = self.criterion(out, batch)
159
- if isinstance(loss, dict):
160
- out.update(loss)
161
- else:
162
- out["loss"] = loss
163
-
164
- return out
165
-
166
- def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
167
- x = self.normalize(x) if normalize else x
168
- return self.pooling(self.backbone(x))
169
-
170
- def freeze_backbone(self) -> None:
171
- for param in self.backbone.parameters():
172
- param.requires_grad = False
173
- self.backbone_frozen = True
174
-
175
- def set_criterion(self, loss: nn.Module) -> None:
176
- self.criterion = loss
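
The multi-head variant emits one logits tensor per head under the keys "logits0", "logits1", ..., and its criterion is expected to consume that dict. A toy stand-in for the pattern (sizes illustrative):

import torch
import torch.nn as nn

feature_dim, num_classes = 64, [3, 5]                   # two heads
heads = nn.ModuleList(nn.Linear(feature_dim, n) for n in num_classes)
features = torch.randn(2, feature_dim)
out = {f"logits{i}": head(features) for i, head in enumerate(heads)}
print({k: tuple(v.shape) for k, v in out.items()})
# {'logits0': (2, 3), 'logits1': (2, 5)}
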
skp/models/classification/net2d_multihead_var_embed.py DELETED
@@ -1,186 +0,0 @@
1
- """
2
- Simple model for 2D classification (or regression) with multiple heads
3
- Incorporates embedding of non-image features
4
- Uses timm for backbones
5
- """
6
-
7
- import re
8
- import torch
9
- import torch.nn as nn
10
-
11
- from collections.abc import Sequence
12
- from timm import create_model
13
- from typing import Dict
14
-
15
- from skp.configs.base import Config
16
- from skp.models.modules import FeatureReduction
17
- from skp.models.pooling import get_pool_layer
18
-
19
-
20
- class Net(nn.Module):
21
- def __init__(self, cfg: Config):
22
- super().__init__()
23
- self.cfg = cfg
24
- assert (
25
- isinstance(self.cfg.num_classes, Sequence)
26
- and len(self.cfg.num_classes) == self.cfg.num_heads
27
- ), f"cfg.num_classes should be sequence of length {self.cfg.num_heads} corresponding to each head"
28
- backbone_args = {
29
- "pretrained": self.cfg.pretrained,
30
- "num_classes": 0,
31
- "global_pool": "",
32
- "features_only": self.cfg.features_only,
33
- "in_chans": self.cfg.num_input_channels,
34
- }
35
- if self.cfg.backbone_img_size:
36
- # some models require specifying image size (e.g., coatnet)
37
- if "efficientvit" in self.cfg.backbone:
38
- backbone_args["img_size"] = self.cfg.image_height
39
- else:
40
- backbone_args["img_size"] = (
41
- self.cfg.image_height,
42
- self.cfg.image_width,
43
- )
44
- self.backbone = create_model(self.cfg.backbone, **backbone_args)
45
- # get feature dim by passing sample through net
46
- self.feature_dim = self.backbone(
47
- torch.randn(
48
- (
49
- 2,
50
- self.cfg.num_input_channels,
51
- self.cfg.image_height,
52
- self.cfg.image_width,
53
- )
54
- )
55
- ).size(
56
- -1 if "xcit" in self.cfg.backbone else 1
57
- ) # xcit models are channels-last
58
-
59
- self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
60
- self.pooling = get_pool_layer(self.cfg, dim=2)
61
-
62
- if isinstance(self.cfg.reduce_feature_dim, int):
63
- self.backbone = nn.Sequential(
64
- self.backbone,
65
- FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
66
- )
67
- self.feature_dim = self.cfg.reduce_feature_dim
68
-
69
- self.embed = nn.Embedding(self.cfg.embed_num_classes, self.cfg.embed_dim)
70
- # allows for interaction between elements of image feature vector and embedding
71
- self.mlp = nn.Linear(self.feature_dim + self.cfg.embed_dim, self.feature_dim)
72
- self.dropout = nn.Dropout(p=self.cfg.dropout)
73
- self.linear = nn.ModuleList()
74
- for i in range(self.cfg.num_heads):
75
- self.linear.append(nn.Linear(self.feature_dim, self.cfg.num_classes[i]))
76
-
77
-
78
- if self.cfg.load_pretrained_backbone:
79
- print(
80
- f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
81
- )
82
- weights = torch.load(
83
- self.cfg.load_pretrained_backbone,
84
- map_location=lambda storage, loc: storage,
85
- )["state_dict"]
86
- # Replace model prefix as this does not exist in Net
87
- weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
88
- # Get backbone only
89
- weights = {
90
- re.sub(r"^backbone.", "", k): v
91
- for k, v in weights.items()
92
- if "backbone" in k
93
- }
94
- self.backbone.load_state_dict(weights)
95
-
96
- self.criterion = None
97
-
98
- self.backbone_frozen = False
99
- if self.cfg.freeze_backbone:
100
- self.freeze_backbone()
101
-
102
- def normalize(self, x: torch.Tensor) -> torch.Tensor:
103
- if self.cfg.normalization == "-1_1":
104
- mini, maxi = (
105
- self.cfg.normalization_params["min"],
106
- self.cfg.normalization_params["max"],
107
- )
108
- x = x - mini
109
- x = x / (maxi - mini)
110
- x = x - 0.5
111
- x = x * 2.0
112
- elif self.cfg.normalization == "0_1":
113
- mini, maxi = (
114
- self.cfg.normalization_params["min"],
115
- self.cfg.normalization_params["max"],
116
- )
117
- x = x - mini
118
- x = x / (maxi - mini)
119
- elif self.cfg.normalization == "mean_sd":
120
- mean, sd = (
121
- self.cfg.normalization_params["mean"],
122
- self.cfg.normalization_params["sd"],
123
- )
124
- x = (x - mean) / sd
125
- elif self.cfg.normalization == "per_channel_mean_sd":
126
- mean, sd = (
127
- self.cfg.normalization_params["mean"],
128
- self.cfg.normalization_params["sd"],
129
- )
130
- assert len(mean) == len(sd) == x.size(1)
131
- mean, sd = torch.tensor(mean, device=x.device).unsqueeze(0), torch.tensor(sd, device=x.device).unsqueeze(0)  # keep stats on the input's device
132
- for i in range(x.ndim - 2):
133
- mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
134
- x = (x - mean) / sd
135
- elif self.cfg.normalization == "none":
136
- x = x
137
- return x
138
-
139
- def forward(
140
- self, batch: Dict, return_loss: bool = False, return_features: bool = False
141
- ) -> Dict[str, torch.Tensor]:
142
- x = batch["x"]
143
- y = batch.get("y", None)
144
- var = batch["var"]
145
-
146
- if return_loss:
147
- assert y is not None
148
-
149
- features = self.extract_features(x, var, normalize=True)
150
-
151
- out = {}
152
- for head_idx, each_head in enumerate(self.linear):
153
- if self.cfg.multisample_dropout:
154
- logits = torch.stack(
155
- [each_head(self.dropout(features)) for _ in range(5)]
156
- ).mean(0)
157
- else:
158
- logits = each_head(self.dropout(features))
159
- out[f"logits{head_idx}"] = logits
160
-
161
- if return_features:
162
- out["features"] = features
163
- if return_loss:
164
- loss = self.criterion(out, batch)
165
- if isinstance(loss, dict):
166
- out.update(loss)
167
- else:
168
- out["loss"] = loss
169
-
170
- return out
171
-
172
- def extract_features(self, x: torch.Tensor, var: torch.Tensor, normalize: bool = True) -> torch.Tensor:
173
- x = self.normalize(x) if normalize else x
174
- var = self.embed(var)
175
- feat = self.pooling(self.backbone(x))
176
- feat = torch.cat([feat, var], dim=1)
177
- feat = self.mlp(feat)
178
- return feat
179
-
180
- def freeze_backbone(self) -> None:
181
- for param in self.backbone.parameters():
182
- param.requires_grad = False
183
- self.backbone_frozen = True
184
-
185
- def set_criterion(self, loss: nn.Module) -> None:
186
- self.criterion = loss
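
Every head in these models supports multisample dropout: the same features are passed through dropout several times and the resulting logits are averaged, which reduces the variance introduced by any single dropout mask. The pattern in isolation (sizes illustrative):

import torch
import torch.nn as nn

dropout, linear = nn.Dropout(0.5), nn.Linear(8, 2)
feats = torch.randn(2, 8)
logits = torch.stack([linear(dropout(feats)) for _ in range(5)]).mean(0)
print(logits.shape)  # torch.Size([2, 2])
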
skp/models/classification/net2d_var_embed.py DELETED
@@ -1,178 +0,0 @@
1
- """
2
- Simple model for 2D classification (or regression)
3
- Incorporates embedding of non-image features
4
- Uses timm for backbones
5
- """
6
-
7
- import re
8
- import torch
9
- import torch.nn as nn
10
-
11
- from timm import create_model
12
- from typing import Dict
13
-
14
- from skp.configs.base import Config
15
- from skp.models.modules import FeatureReduction
16
- from skp.models.pooling import get_pool_layer
17
-
18
-
19
- class Net(nn.Module):
20
- def __init__(self, cfg: Config):
21
- super().__init__()
22
- self.cfg = cfg
23
- backbone_args = {
24
- "pretrained": self.cfg.pretrained,
25
- "num_classes": 0,
26
- "global_pool": "",
27
- "features_only": self.cfg.features_only,
28
- "in_chans": self.cfg.num_input_channels,
29
- }
30
- if self.cfg.backbone_img_size:
31
- # some models require specifying image size (e.g., coatnet)
32
- if "efficientvit" in self.cfg.backbone:
33
- backbone_args["img_size"] = self.cfg.image_height
34
- else:
35
- backbone_args["img_size"] = (
36
- self.cfg.image_height,
37
- self.cfg.image_width,
38
- )
39
- self.backbone = create_model(self.cfg.backbone, **backbone_args)
40
- # get feature dim by passing sample through net
41
- self.feature_dim = self.backbone(
42
- torch.randn(
43
- (
44
- 2,
45
- self.cfg.num_input_channels,
46
- self.cfg.image_height,
47
- self.cfg.image_width,
48
- )
49
- )
50
- ).size(
51
- -1 if "xcit" in self.cfg.backbone else 1
52
- ) # xcit models are channels-last
53
-
54
- self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1)
55
- self.pooling = get_pool_layer(self.cfg, dim=2)
56
-
57
- if isinstance(self.cfg.reduce_feature_dim, int):
58
- self.backbone = nn.Sequential(
59
- self.backbone,
60
- FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim),
61
- )
62
- self.feature_dim = self.cfg.reduce_feature_dim
63
-
64
- self.embed = nn.Embedding(self.cfg.embed_num_classes, self.cfg.embed_dim)
65
- # allows for interaction between elements of image feature vector and embedding
66
- self.mlp = nn.Linear(self.feature_dim + self.cfg.embed_dim, self.feature_dim)
67
- self.dropout = nn.Dropout(p=self.cfg.dropout)
68
- self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes)
69
-
70
- if self.cfg.load_pretrained_backbone:
71
- print(
72
- f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..."
73
- )
74
- weights = torch.load(
75
- self.cfg.load_pretrained_backbone,
76
- map_location=lambda storage, loc: storage,
77
- )["state_dict"]
78
- # Replace model prefix as this does not exist in Net
79
- weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()}
80
- # Get backbone only
81
- weights = {
82
- re.sub(r"^backbone.", "", k): v
83
- for k, v in weights.items()
84
- if "backbone" in k
85
- }
86
- self.backbone.load_state_dict(weights)
87
-
88
- self.criterion = None
89
-
90
- self.backbone_frozen = False
91
- if self.cfg.freeze_backbone:
92
- self.freeze_backbone()
93
-
94
- def normalize(self, x: torch.Tensor) -> torch.Tensor:
95
- if self.cfg.normalization == "-1_1":
96
- mini, maxi = (
97
- self.cfg.normalization_params["min"],
98
- self.cfg.normalization_params["max"],
99
- )
100
- x = x - mini
101
- x = x / (maxi - mini)
102
- x = x - 0.5
103
- x = x * 2.0
104
- elif self.cfg.normalization == "0_1":
105
- mini, maxi = (
106
- self.cfg.normalization_params["min"],
107
- self.cfg.normalization_params["max"],
108
- )
109
- x = x - mini
110
- x = x / (maxi - mini)
111
- elif self.cfg.normalization == "mean_sd":
112
- mean, sd = (
113
- self.cfg.normalization_params["mean"],
114
- self.cfg.normalization_params["sd"],
115
- )
116
- x = (x - mean) / sd
117
- elif self.cfg.normalization == "per_channel_mean_sd":
118
- mean, sd = (
119
- self.cfg.normalization_params["mean"],
120
- self.cfg.normalization_params["sd"],
121
- )
122
- assert len(mean) == len(sd) == x.size(1)
123
- mean, sd = torch.tensor(mean, device=x.device).unsqueeze(0), torch.tensor(sd, device=x.device).unsqueeze(0)  # keep stats on the input's device
124
- for i in range(x.ndim - 2):
125
- mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1)
126
- x = (x - mean) / sd
127
- elif self.cfg.normalization == "none":
128
- x = x
129
- return x
130
-
131
- def forward(
132
- self, batch: Dict, return_loss: bool = False, return_features: bool = False
133
- ) -> Dict[str, torch.Tensor]:
134
- x = batch["x"]
135
- y = batch.get("y", None)
136
- var = batch["var"]
137
-
138
- if return_loss:
139
- assert y is not None
140
-
141
- features = self.extract_features(x, var, normalize=True)
142
-
143
- if self.cfg.multisample_dropout:
144
- logits = torch.stack(
145
- [self.linear(self.dropout(features)) for _ in range(5)]
146
- ).mean(0)
147
- else:
148
- logits = self.linear(self.dropout(features))
149
-
150
- out = {"logits": logits}
151
- if return_features:
152
- out["features"] = features
153
- if return_loss:
154
- loss = self.criterion(out, batch)
155
- if isinstance(loss, dict):
156
- out.update(loss)
157
- else:
158
- out["loss"] = loss
159
-
160
- return out
161
-
162
- def extract_features(
163
- self, x: torch.Tensor, var: torch.Tensor, normalize: bool = True
164
- ) -> torch.Tensor:
165
- x = self.normalize(x) if normalize else x
166
- var = self.embed(var)
167
- feat = self.pooling(self.backbone(x))
168
- feat = torch.cat([feat, var], dim=1)
169
- feat = self.mlp(feat)
170
- return feat
171
-
172
- def freeze_backbone(self) -> None:
173
- for param in self.backbone.parameters():
174
- param.requires_grad = False
175
- self.backbone_frozen = True
176
-
177
- def set_criterion(self, loss: nn.Module) -> None:
178
- self.criterion = loss
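
extract_features above fuses the pooled image features with a learned embedding of a categorical variable by concatenation followed by a linear layer, which lets the embedding interact with every element of the image feature vector. The fusion in isolation (sizes illustrative):

import torch
import torch.nn as nn

feature_dim, embed_dim = 64, 32
embed = nn.Embedding(2, embed_dim)            # binary categorical variable
mlp = nn.Linear(feature_dim + embed_dim, feature_dim)

feat = torch.randn(4, feature_dim)            # stand-in for pooled backbone output
var = torch.tensor([0, 1, 1, 0])              # one category index per sample
fused = mlp(torch.cat([feat, embed(var)], dim=1))
print(fused.shape)                            # torch.Size([4, 64])
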
skp/models/modules.py DELETED
@@ -1,32 +0,0 @@
1
- """
2
- Contains commonly used neural net modules.
3
- """
4
-
5
- import math
6
- import torch
7
- import torch.nn as nn
8
-
9
-
10
- class FeatureReduction(nn.Module):
11
- """
12
- Reduce feature dimensionality
13
- Intended use is after the last layer of the neural net backbone, before pooling
14
- Grouped convolution is used to reduce the number of extra parameters
15
- """
16
-
17
- def __init__(self, feature_dim: int, reduce_feature_dim: int):
18
- super().__init__()
19
- groups = math.gcd(feature_dim, reduce_feature_dim)
20
- self.reduce = nn.Conv2d(
21
- feature_dim,
22
- reduce_feature_dim,
23
- groups=groups,
24
- kernel_size=1,
25
- stride=1,
26
- bias=False,
27
- )
28
- self.bn = nn.BatchNorm2d(reduce_feature_dim)
29
- self.act = nn.ReLU()
30
-
31
- def forward(self, x: torch.Tensor) -> torch.Tensor:
32
- return self.act(self.bn(self.reduce(x)))
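
The gcd-derived group count is what keeps FeatureReduction cheap: a grouped 1x1 convolution holds in_channels / groups weights per output channel instead of in_channels. Comparing parameter counts (channel sizes illustrative):

import math
import torch.nn as nn

fin, fout = 1280, 256
g = math.gcd(fin, fout)                                    # 256 groups
dense = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
grouped = nn.Conv2d(fin, fout, kernel_size=1, groups=g, bias=False)
print(dense.weight.numel(), grouped.weight.numel())        # 327680 vs 1280
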
skp/models/pooling.py DELETED
@@ -1,150 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from timm.layers import SelectAdaptivePool2d
6
-
7
- from skp.configs.base import Config
8
-
9
-
10
- class GeM(nn.Module):
11
- def __init__(
12
- self, p: int = 3, eps: float = 1e-6, dim: int = 2, flatten: bool = True
13
- ):
14
- super().__init__()
15
- self.p = nn.Parameter(torch.ones(1) * p)
16
- self.eps = eps
17
- assert dim in {2, 3}, f"dim must be one of [2, 3], not {dim}"
18
- self.dim = dim
19
- if self.dim == 2:
20
- self.func = F.adaptive_avg_pool2d
21
- elif self.dim == 3:
22
- self.func = F.adaptive_avg_pool3d
23
- self.flatten = nn.Flatten(1) if flatten else nn.Identity()
24
-
25
- def forward(self, x: torch.Tensor) -> torch.Tensor:
26
- # assumes x.shape is (n, c, [t], h, w)
27
- x = self.func(x.clamp(min=self.eps).pow(self.p), output_size=1).pow(
28
- 1.0 / self.p
29
- )
30
- return self.flatten(x)
31
-
32
-
33
- def adaptive_avgmax_pool3d(x: torch.Tensor, output_size: int = 1):
34
- x_avg = F.adaptive_avg_pool3d(x, output_size)
35
- x_max = F.adaptive_max_pool3d(x, output_size)
36
- return 0.5 * (x_avg + x_max)
37
-
38
-
39
- def adaptive_catavgmax_pool3d(x: torch.Tensor, output_size: int = 1):
40
- x_avg = F.adaptive_avg_pool3d(x, output_size)
41
- x_max = F.adaptive_max_pool3d(x, output_size)
42
- return torch.cat((x_avg, x_max), 1)
43
-
44
-
45
- def select_adaptive_pool3d(x: torch.Tensor, pool_type: str, output_size: int = 1) -> torch.Tensor:
46
- """Selectable global pooling function with dynamic input kernel size"""
47
- if pool_type == "avg":
48
- x = F.adaptive_avg_pool3d(x, output_size)
49
- elif pool_type == "avgmax":
50
- x = adaptive_avgmax_pool3d(x, output_size)
51
- elif pool_type == "catavgmax":
52
- x = adaptive_catavgmax_pool3d(x, output_size)
53
- elif pool_type == "max":
54
- x = F.adaptive_max_pool3d(x, output_size)
55
- else:
56
- assert False, "Invalid pool type: %s" % pool_type
57
- return x
58
-
59
-
60
- class FastAdaptiveAvgPool3d(nn.Module):
61
- def __init__(self, flatten: bool = False):
62
- super(FastAdaptiveAvgPool3d, self).__init__()
63
- self.flatten = flatten
64
-
65
- def forward(self, x: torch.Tensor) -> torch.Tensor:
66
- return x.mean((2, 3, 4), keepdim=not self.flatten)
67
-
68
-
69
- class AdaptiveAvgMaxPool3d(nn.Module):
70
- def __init__(self, output_size: int = 1):
71
- super(AdaptiveAvgMaxPool3d, self).__init__()
72
- self.output_size = output_size
73
-
74
- def forward(self, x: torch.Tensor) -> torch.Tensor:
75
- return adaptive_avgmax_pool3d(x, self.output_size)
76
-
77
-
78
- class AdaptiveCatAvgMaxPool3d(nn.Module):
79
- def __init__(self, output_size: int = 1):
80
- super(AdaptiveCatAvgMaxPool3d, self).__init__()
81
- self.output_size = output_size
82
-
83
- def forward(self, x: torch.Tensor) -> torch.Tensor:
84
- return adaptive_catavgmax_pool3d(x, self.output_size)
85
-
86
-
87
- class SelectAdaptivePool3d(nn.Module):
88
- """Selectable global pooling layer with dynamic input kernel size"""
89
-
90
- def __init__(self, output_size: int = 1, pool_type: str = "fast", flatten: bool = False):
91
- super(SelectAdaptivePool3d, self).__init__()
92
- self.pool_type = (
93
- pool_type or ""
94
- ) # convert other falsy values to empty string for consistent TS typing
95
- self.flatten = nn.Flatten(1) if flatten else nn.Identity()
96
- if pool_type == "":
97
- self.pool = nn.Identity() # pass through
98
- elif pool_type == "fast":
99
- assert output_size == 1
100
- self.pool = FastAdaptiveAvgPool3d(flatten)
101
- self.flatten = nn.Identity()
102
- elif pool_type == "avg":
103
- self.pool = nn.AdaptiveAvgPool3d(output_size)
104
- elif pool_type == "avgmax":
105
- self.pool = AdaptiveAvgMaxPool3d(output_size)
106
- elif pool_type == "catavgmax":
107
- self.pool = AdaptiveCatAvgMaxPool3d(output_size)
108
- elif pool_type == "max":
109
- self.pool = nn.AdaptiveMaxPool3d(output_size)
110
- else:
111
- assert False, "Invalid pool type: %s" % pool_type
112
-
113
- def is_identity(self) -> bool:
114
- return not self.pool_type
115
-
116
- def forward(self, x: torch.Tensor) -> torch.Tensor:
117
- x = self.pool(x)
118
- x = self.flatten(x)
119
- return x
120
-
121
- def __repr__(self):
122
- return (
123
- self.__class__.__name__
124
- + " ("
125
- + "pool_type="
126
- + self.pool_type
127
- + ", flatten="
128
- + str(self.flatten)
129
- + ")"
130
- )
131
-
132
-
133
- def get_pool_layer(cfg: Config, dim: int) -> nn.Module:
134
- assert cfg.pool in [
135
- "avg",
136
- "max",
137
- "fast",
138
- "avgmax",
139
- "catavgmax",
140
- "gem",
141
- ""
142
- ], f"{cfg.pool} is not a valid pooling layer"
143
- params = cfg.pool_params or {}
144
- if cfg.pool == "gem":
145
- return GeM(**params, dim=dim)
146
- else:
147
- if dim == 2:
148
- return SelectAdaptivePool2d(pool_type=cfg.pool, flatten=True)
149
- elif dim == 3:
150
- return SelectAdaptivePool3d(pool_type=cfg.pool, flatten=True)
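
GeM interpolates between average and max pooling via the exponent p (learnable above, since it is an nn.Parameter): p = 1 recovers average pooling, and larger p pushes the result toward the max. A quick functional check:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 8, 8)

def gem(t, p):
    return F.adaptive_avg_pool2d(t.clamp(min=1e-6).pow(p), 1).pow(1.0 / p)

avg = F.adaptive_avg_pool2d(x, 1)
mx = F.adaptive_max_pool2d(x, 1)
print(torch.allclose(gem(x, 1.0), avg))                          # True: p=1 is avg
print(bool(((gem(x, 3.0) >= avg) & (gem(x, 3.0) <= mx)).all()))  # True: between avg and max
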
skp/utils.py DELETED
@@ -1,49 +0,0 @@
1
- import re
2
- import torch
3
-
4
- from skp.configs import Config
5
- from importlib import import_module
6
- from typing import Dict, Optional, Sequence
7
-
8
-
9
- def load_weights_from_path(path: str) -> Dict[str, torch.Tensor]:
10
- w = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)[
11
- "state_dict"
12
- ]
13
- w = {
14
- re.sub(r"^model.", "", k): v
15
- for k, v in w.items()
16
- if k.startswith("model.") and "criterion" not in k
17
- }
18
- return w
19
-
20
-
21
- def load_model_from_config(
22
- cfg: Config,
23
- weights_path: Optional[str] = None,
24
- device: str = "cpu",
25
- eval_mode: bool = True,
26
- ) -> torch.nn.Module:
27
- model = import_module(f"skp.models.{cfg.model}").Net(cfg)
28
- if weights_path:
29
- weights = load_weights_from_path(weights_path)
30
- model.load_state_dict(weights)
31
- model = model.to(device).train(mode=not eval_mode)
32
- return model
33
-
34
-
35
- def load_kfold_ensemble_as_list(
36
- cfg: Config,
37
- weights_paths: Sequence[str],
38
- device: str = "cpu",
39
- eval_mode: bool = True,
40
- ) -> torch.nn.ModuleList:
41
- # multiple folds for the same model
42
- # does not work for ensembling different types of models
43
- # assumes that trained weights are available
44
- # otherwise there is no reason to load several randomly initialized copies of the same model
45
- model_list = torch.nn.ModuleList()
46
- for each_weight in weights_paths:
47
- model = load_model_from_config(cfg, each_weight, device, eval_mode)
48
- model_list.append(model)
49
- return model_list
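
Typical use of these helpers, e.g. averaging fold predictions at inference (the config module path and checkpoint names below are hypothetical placeholders):

import torch
from importlib import import_module

cfg = import_module("skp.configs.some_task.cfg_example").cfg   # hypothetical
folds = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]             # hypothetical
ensemble = load_kfold_ensemble_as_list(cfg, folds, device="cpu")

x = torch.randn(2, cfg.num_input_channels, cfg.image_height, cfg.image_width)
with torch.inference_mode():
    logits = torch.stack([m({"x": x})["logits"] for m in ensemble]).mean(0)
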