wangerniu committed
Commit c9b5796 · 1 Parent(s): 1e5420b

Add necessary files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. conf/maplocnet.yaml +105 -0
  2. conf/maplocnetsingle-101.yaml +105 -0
  3. conf/maplocnetsingle.yaml +105 -0
  4. conf/maplocnetsingle0526.yaml +105 -0
  5. conf/maplocnetsingleunet.yaml +105 -0
  6. conf/maplocnetsinglhub_DDRNet.yaml +107 -0
  7. conf/maplocnetsinglhub_FPN-resnet18WeightedEmbedding.yaml +112 -0
  8. conf/maplocnetsinglhub_FPN-resnet34LightWeightedEmbedding.yaml +112 -0
  9. conf/maplocnetsinglhub_FPN-resnet34WeightedEmbedding.yaml +112 -0
  10. conf/maplocnetsinglhub_FPN-resnet50.yaml +111 -0
  11. conf/maplocnetsinglhub_FPN-resnet50WeightedEmbedding.yaml +112 -0
  12. conf/maplocnetsinglhub_FPN.yaml +107 -0
  13. conf/maplocnetsinglhub_FPN_Mobileone.yaml +107 -0
  14. conf/maplocnetsinglhub_PSP.yaml +107 -0
  15. conf/orienternet.yaml +103 -0
  16. dataset/UAV/dataset.py +116 -0
  17. dataset/__init__.py +4 -0
  18. dataset/dataset.py +109 -0
  19. dataset/image.py +140 -0
  20. dataset/torch.py +111 -0
  21. evaluation/kitti.py +89 -0
  22. evaluation/mapillary.py +0 -0
  23. evaluation/run.py +252 -0
  24. evaluation/utils.py +40 -0
  25. evaluation/viz.py +178 -0
  26. feature_extractor_models/__init__.py +82 -0
  27. feature_extractor_models/__version__.py +3 -0
  28. feature_extractor_models/base/__init__.py +13 -0
  29. feature_extractor_models/base/heads.py +34 -0
  30. feature_extractor_models/base/hub_mixin.py +154 -0
  31. feature_extractor_models/base/initialization.py +26 -0
  32. feature_extractor_models/base/model.py +71 -0
  33. feature_extractor_models/base/modules.py +131 -0
  34. feature_extractor_models/decoders/__init__.py +0 -0
  35. feature_extractor_models/decoders/deeplabv3/__init__.py +3 -0
  36. feature_extractor_models/decoders/deeplabv3/decoder.py +220 -0
  37. feature_extractor_models/decoders/deeplabv3/model.py +178 -0
  38. feature_extractor_models/decoders/fpn/__init__.py +3 -0
  39. feature_extractor_models/decoders/fpn/decoder.py +133 -0
  40. feature_extractor_models/decoders/fpn/model.py +107 -0
  41. feature_extractor_models/decoders/lightfpn/__init__.py +3 -0
  42. feature_extractor_models/decoders/lightfpn/decoder.py +144 -0
  43. feature_extractor_models/decoders/lightfpn/model.py +107 -0
  44. feature_extractor_models/decoders/linknet/__init__.py +3 -0
  45. feature_extractor_models/decoders/linknet/decoder.py +82 -0
  46. feature_extractor_models/decoders/linknet/model.py +98 -0
  47. feature_extractor_models/decoders/manet/__init__.py +3 -0
  48. feature_extractor_models/decoders/manet/decoder.py +187 -0
  49. feature_extractor_models/decoders/manet/model.py +102 -0
  50. feature_extractor_models/decoders/pan/__init__.py +3 -0
conf/maplocnet.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: true
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v2
49
+ backbone:
50
+ encoder: resnet50
51
+ pretrained: true
52
+ output_dim: 8
53
+ num_downsample: null
54
+ remove_stride_from_first_conv: false
55
+ name: maplocnet
56
+ matching_dim: 8
57
+ z_max: 32
58
+ x_max: 32
59
+ pixel_per_meter: 1
60
+ num_scale_bins: 33
61
+ num_rotations: 64
62
+ map_encoder:
63
+ embedding_dim: 16
64
+ output_dim: 8
65
+ num_classes:
66
+ areas: 7
67
+ ways: 10
68
+ nodes: 33
69
+ backbone:
70
+ encoder: vgg19
71
+ pretrained: false
72
+ output_scales:
73
+ - 0
74
+ num_downsample: 3
75
+ decoder:
76
+ - 128
77
+ - 64
78
+ - 64
79
+ padding: replicate
80
+ unary_prior: false
81
+ bev_net:
82
+ num_blocks: 4
83
+ latent_dim: 128
84
+ output_dim: 8
85
+ confidence: true
86
+ experiment:
87
+ name: maplocanet_0526_re
88
+ gpus: 2
89
+ seed: 0
90
+ training:
91
+ lr: 0.0001
92
+ lr_scheduler: null
93
+ finetune_from_checkpoint: null
94
+ trainer:
95
+ val_check_interval: 1000
96
+ log_every_n_steps: 100
97
+ # limit_val_batches: 1000
98
+ max_steps: 200000
99
+ devices: ${experiment.gpus}
100
+ checkpointing:
101
+ monitor: "loss/total/val"
102
+ save_top_k: 10
103
+ mode: min
104
+
105
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsingle-101.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v2
49
+ backbone:
50
+ encoder: resnet101
51
+ pretrained: true
52
+ output_dim: 8
53
+ num_downsample: null
54
+ remove_stride_from_first_conv: false
55
+ name: maplocnet
56
+ matching_dim: 8
57
+ z_max: 32
58
+ x_max: 32
59
+ pixel_per_meter: 1
60
+ num_scale_bins: 33
61
+ num_rotations: 64
62
+ map_encoder:
63
+ embedding_dim: 48
64
+ output_dim: 8
65
+ num_classes:
66
+ areas: 7
67
+ ways: 10
68
+ nodes: 33
69
+ backbone:
70
+ encoder: vgg19
71
+ pretrained: false
72
+ output_scales:
73
+ - 0
74
+ num_downsample: 3
75
+ decoder:
76
+ - 128
77
+ - 64
78
+ - 64
79
+ padding: replicate
80
+ unary_prior: false
81
+ bev_net:
82
+ num_blocks: 4
83
+ latent_dim: 128
84
+ output_dim: 8
85
+ confidence: true
86
+ experiment:
87
+ name: maplocanet_523_single_A100_no_mutil_scale_augmentation_resnet101_nosingle
88
+ gpus: 2
89
+ seed: 0
90
+ training:
91
+ lr: 0.0001
92
+ lr_scheduler: null
93
+ finetune_from_checkpoint: null
94
+ trainer:
95
+ val_check_interval: 1000
96
+ log_every_n_steps: 100
97
+ # limit_val_batches: 1000
98
+ max_steps: 200000
99
+ devices: ${experiment.gpus}
100
+ checkpointing:
101
+ monitor: "val/xy_recall_1m"
102
+ save_top_k: 10
103
+ mode: min
104
+
105
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsingle.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v2
49
+ backbone:
50
+ encoder: resnet101
51
+ pretrained: true
52
+ output_dim: 8
53
+ num_downsample: null
54
+ remove_stride_from_first_conv: false
55
+ name: maplocnet
56
+ matching_dim: 8
57
+ z_max: 32
58
+ x_max: 32
59
+ pixel_per_meter: 1
60
+ num_scale_bins: 33
61
+ num_rotations: 64
62
+ map_encoder:
63
+ embedding_dim: 48
64
+ output_dim: 8
65
+ num_classes:
66
+ all: 50
67
+ # ways: 10
68
+ # nodes: 33
69
+ backbone:
70
+ encoder: vgg19
71
+ pretrained: false
72
+ output_scales:
73
+ - 0
74
+ num_downsample: 3
75
+ decoder:
76
+ - 128
77
+ - 64
78
+ - 64
79
+ padding: replicate
80
+ unary_prior: false
81
+ bev_net:
82
+ num_blocks: 4
83
+ latent_dim: 128
84
+ output_dim: 8
85
+ confidence: true
86
+ experiment:
87
+ name: maplocanet_523_single_A100_no_mutil_scale_augmentation_resnet101_2
88
+ gpus: 2
89
+ seed: 0
90
+ training:
91
+ lr: 0.0001
92
+ lr_scheduler: null
93
+ finetune_from_checkpoint: null
94
+ trainer:
95
+ val_check_interval: 1000
96
+ log_every_n_steps: 100
97
+ # limit_val_batches: 1000
98
+ max_steps: 200000
99
+ devices: ${experiment.gpus}
100
+ checkpointing:
101
+ monitor: "val/xy_recall_1m"
102
+ save_top_k: 10
103
+ mode: min
104
+
105
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsingle0526.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: false
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v2
49
+ backbone:
50
+ encoder: resnet50
51
+ pretrained: true
52
+ output_dim: 8
53
+ num_downsample: null
54
+ remove_stride_from_first_conv: false
55
+ name: maplocnet
56
+ matching_dim: 8
57
+ z_max: 32
58
+ x_max: 32
59
+ pixel_per_meter: 1
60
+ num_scale_bins: 33
61
+ num_rotations: 64
62
+ map_encoder:
63
+ embedding_dim: 48
64
+ output_dim: 8
65
+ num_classes:
66
+ all: 50
67
+ # ways: 10
68
+ # nodes: 33
69
+ backbone:
70
+ encoder: vgg19
71
+ pretrained: false
72
+ output_scales:
73
+ - 0
74
+ num_downsample: 3
75
+ decoder:
76
+ - 128
77
+ - 64
78
+ - 64
79
+ padding: replicate
80
+ unary_prior: false
81
+ bev_net:
82
+ num_blocks: 4
83
+ latent_dim: 128
84
+ output_dim: 8
85
+ confidence: true
86
+ experiment:
87
+ name: maplocanet_523_single_A100_no_mutil_scale
88
+ gpus: 2
89
+ seed: 0
90
+ training:
91
+ lr: 0.0001
92
+ lr_scheduler: null
93
+ finetune_from_checkpoint: null
94
+ trainer:
95
+ val_check_interval: 1000
96
+ log_every_n_steps: 100
97
+ # limit_val_batches: 1000
98
+ max_steps: 200000
99
+ devices: ${experiment.gpus}
100
+ checkpointing:
101
+ monitor: "val/xy_recall_1m"
102
+ save_top_k: 10
103
+ mode: min
104
+
105
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsingleunet.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v3
49
+ backbone:
50
+ # encoder: resnet101
51
+ # pretrained: true
52
+ output_dim: 8
53
+ # num_downsample: null
54
+ # remove_stride_from_first_conv: false
55
+ name: maplocnet
56
+ matching_dim: 8
57
+ z_max: 32
58
+ x_max: 32
59
+ pixel_per_meter: 1
60
+ num_scale_bins: 33
61
+ num_rotations: 64
62
+ map_encoder:
63
+ embedding_dim: 48
64
+ output_dim: 8
65
+ num_classes:
66
+ all: 50
67
+ # ways: 10
68
+ # nodes: 33
69
+ backbone:
70
+ encoder: vgg19
71
+ pretrained: false
72
+ output_scales:
73
+ - 0
74
+ num_downsample: 3
75
+ decoder:
76
+ - 128
77
+ - 64
78
+ - 64
79
+ padding: replicate
80
+ unary_prior: false
81
+ bev_net:
82
+ num_blocks: 4
83
+ latent_dim: 128
84
+ output_dim: 8
85
+ confidence: true
86
+ experiment:
87
+ name: maplocanet_601_unet
88
+ gpus: 2
89
+ seed: 0
90
+ training:
91
+ lr: 0.0001
92
+ lr_scheduler: null
93
+ finetune_from_checkpoint: null
94
+ trainer:
95
+ val_check_interval: 1000
96
+ log_every_n_steps: 100
97
+ # limit_val_batches: 1000
98
+ max_steps: 200000
99
+ devices: ${experiment.gpus}
100
+ checkpointing:
101
+ monitor: "val/xy_recall_1m"
102
+ save_top_k: 10
103
+ mode: min
104
+
105
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_DDRNet.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v5
49
+ architecture: DDRNet23s
50
+ backbone:
51
+ # encoder: resnet50
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ num_classes:
68
+ all: 50
69
+ # ways: 10
70
+ # nodes: 33
71
+ backbone:
72
+ encoder: vgg19
73
+ pretrained: false
74
+ output_scales:
75
+ - 0
76
+ num_downsample: 3
77
+ decoder:
78
+ - 128
79
+ - 64
80
+ - 64
81
+ padding: replicate
82
+ unary_prior: false
83
+ bev_net:
84
+ num_blocks: 4
85
+ latent_dim: 128
86
+ output_dim: 8
87
+ confidence: true
88
+ experiment:
89
+ name: maplocanet_602_hub_DDRnet
90
+ gpus: 2
91
+ seed: 0
92
+ training:
93
+ lr: 0.0001
94
+ lr_scheduler: null
95
+ finetune_from_checkpoint: null
96
+ trainer:
97
+ val_check_interval: 1000
98
+ log_every_n_steps: 100
99
+ # limit_val_batches: 1000
100
+ max_steps: 200000
101
+ devices: ${experiment.gpus}
102
+ checkpointing:
103
+ monitor: "val/xy_recall_1m"
104
+ save_top_k: 5
105
+ mode: max
106
+
107
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN-resnet18WeightedEmbedding.yaml ADDED
@@ -0,0 +1,112 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: false
36
+ flip: false
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: resnet18
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ weighted_embedding: ImprovedAttentionEmbedding
68
+ num_classes:
69
+ all: 50
70
+ # ways: 10
71
+ # nodes: 33
72
+ backbone:
73
+ encoder: vgg19
74
+ pretrained: false
75
+ output_scales:
76
+ - 0
77
+ num_downsample: 3
78
+ decoder:
79
+ - 128
80
+ - 64
81
+ - 64
82
+ padding: replicate
83
+ unary_prior: false
84
+ bev_net:
85
+ num_blocks: 4
86
+ latent_dim: 128
87
+ output_dim: 8
88
+ confidence: true
89
+ experiment:
90
+ name: maplocanet_602_hub_FPN_norelu_resnet18_ImprovedAttentionEmbedding
91
+ gpus: 5
92
+ seed: 42
93
+ training:
94
+ lr: 0.0001
95
+ lr_scheduler:
96
+ name: StepLR
97
+ args:
98
+ step_size: 10
99
+ gamma: 0.1
100
+ finetune_from_checkpoint: null
101
+ trainer:
102
+ val_check_interval: 1000
103
+ log_every_n_steps: 100
104
+ # limit_val_batches: 1000
105
+ max_steps: 300000
106
+ devices: ${experiment.gpus}
107
+ checkpointing:
108
+ monitor: "val/xy_recall_1m"
109
+ save_top_k: 5
110
+ mode: max
111
+
112
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN-resnet34LightWeightedEmbedding.yaml ADDED
@@ -0,0 +1,112 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: false
36
+ flip: false
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: LightFPN
50
+ backbone:
51
+ encoder: resnet34
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ weighted_embedding: ImprovedAttentionEmbedding
68
+ num_classes:
69
+ all: 50
70
+ # ways: 10
71
+ # nodes: 33
72
+ backbone:
73
+ encoder: vgg19
74
+ pretrained: false
75
+ output_scales:
76
+ - 0
77
+ num_downsample: 3
78
+ decoder:
79
+ - 128
80
+ - 64
81
+ - 64
82
+ padding: replicate
83
+ unary_prior: false
84
+ bev_net:
85
+ num_blocks: 4
86
+ latent_dim: 128
87
+ output_dim: 8
88
+ confidence: true
89
+ experiment:
90
+ name: maplocanet_602_hub_FPN_norelu_resnet34Light_ImprovedAttentionEmbedding
91
+ gpus: 5
92
+ seed: 42
93
+ training:
94
+ lr: 0.0001
95
+ lr_scheduler:
96
+ name: StepLR
97
+ args:
98
+ step_size: 10
99
+ gamma: 0.1
100
+ finetune_from_checkpoint: null
101
+ trainer:
102
+ val_check_interval: 1000
103
+ log_every_n_steps: 100
104
+ # limit_val_batches: 1000
105
+ max_steps: 300000
106
+ devices: ${experiment.gpus}
107
+ checkpointing:
108
+ monitor: "val/xy_recall_1m"
109
+ save_top_k: 30
110
+ mode: max
111
+
112
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN-resnet34WeightedEmbedding.yaml ADDED
@@ -0,0 +1,112 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: false
36
+ flip: false
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: resnet34
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ weighted_embedding: ImprovedAttentionEmbedding
68
+ num_classes:
69
+ all: 50
70
+ # ways: 10
71
+ # nodes: 33
72
+ backbone:
73
+ encoder: vgg19
74
+ pretrained: false
75
+ output_scales:
76
+ - 0
77
+ num_downsample: 3
78
+ decoder:
79
+ - 128
80
+ - 64
81
+ - 64
82
+ padding: replicate
83
+ unary_prior: false
84
+ bev_net:
85
+ num_blocks: 4
86
+ latent_dim: 128
87
+ output_dim: 8
88
+ confidence: true
89
+ experiment:
90
+ name: maplocanet_602_hub_FPN_norelu_resnet34_ImprovedAttentionEmbedding
91
+ gpus: 5
92
+ seed: 42
93
+ training:
94
+ lr: 0.0001
95
+ lr_scheduler:
96
+ name: StepLR
97
+ args:
98
+ step_size: 10
99
+ gamma: 0.1
100
+ finetune_from_checkpoint: null
101
+ trainer:
102
+ val_check_interval: 1000
103
+ log_every_n_steps: 100
104
+ # limit_val_batches: 1000
105
+ max_steps: 300000
106
+ devices: ${experiment.gpus}
107
+ checkpointing:
108
+ monitor: "val/xy_recall_1m"
109
+ save_top_k: 5
110
+ mode: max
111
+
112
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN-resnet50.yaml ADDED
@@ -0,0 +1,111 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: false
36
+ flip: false
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: resnet50
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ num_classes:
68
+ all: 50
69
+ # ways: 10
70
+ # nodes: 33
71
+ backbone:
72
+ encoder: vgg19
73
+ pretrained: false
74
+ output_scales:
75
+ - 0
76
+ num_downsample: 3
77
+ decoder:
78
+ - 128
79
+ - 64
80
+ - 64
81
+ padding: replicate
82
+ unary_prior: false
83
+ bev_net:
84
+ num_blocks: 4
85
+ latent_dim: 128
86
+ output_dim: 8
87
+ confidence: true
88
+ experiment:
89
+ name: maplocanet_602_hub_FPN_norelu_resnet50_temp
90
+ gpus: 2
91
+ seed: 42
92
+ training:
93
+ lr: 0.0001
94
+ lr_scheduler:
95
+ name: StepLR
96
+ args:
97
+ step_size: 10
98
+ gamma: 0.1
99
+ finetune_from_checkpoint: null
100
+ trainer:
101
+ val_check_interval: 1000
102
+ log_every_n_steps: 100
103
+ # limit_val_batches: 1000
104
+ max_steps: 300000
105
+ devices: ${experiment.gpus}
106
+ checkpointing:
107
+ monitor: "val/xy_recall_1m"
108
+ save_top_k: 5
109
+ mode: max
110
+
111
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN-resnet50WeightedEmbedding.yaml ADDED
@@ -0,0 +1,112 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: false
36
+ flip: false
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: resnet50
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ weighted_embedding: ImprovedAttentionEmbedding
68
+ num_classes:
69
+ all: 50
70
+ # ways: 10
71
+ # nodes: 33
72
+ backbone:
73
+ encoder: vgg19
74
+ pretrained: false
75
+ output_scales:
76
+ - 0
77
+ num_downsample: 3
78
+ decoder:
79
+ - 128
80
+ - 64
81
+ - 64
82
+ padding: replicate
83
+ unary_prior: false
84
+ bev_net:
85
+ num_blocks: 4
86
+ latent_dim: 128
87
+ output_dim: 8
88
+ confidence: true
89
+ experiment:
90
+ name: maplocanet_602_hub_FPN_norelu_resnet50_ImprovedAttentionEmbedding
91
+ gpus: 3
92
+ seed: 42
93
+ training:
94
+ lr: 0.0001
95
+ lr_scheduler:
96
+ name: StepLR
97
+ args:
98
+ step_size: 10
99
+ gamma: 0.1
100
+ finetune_from_checkpoint: null
101
+ trainer:
102
+ val_check_interval: 1000
103
+ log_every_n_steps: 100
104
+ # limit_val_batches: 1000
105
+ max_steps: 300000
106
+ devices: ${experiment.gpus}
107
+ checkpointing:
108
+ monitor: "val/xy_recall_1m"
109
+ save_top_k: 5
110
+ mode: max
111
+
112
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: resnet101
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ num_classes:
68
+ all: 50
69
+ # ways: 10
70
+ # nodes: 33
71
+ backbone:
72
+ encoder: vgg19
73
+ pretrained: false
74
+ output_scales:
75
+ - 0
76
+ num_downsample: 3
77
+ decoder:
78
+ - 128
79
+ - 64
80
+ - 64
81
+ padding: replicate
82
+ unary_prior: false
83
+ bev_net:
84
+ num_blocks: 4
85
+ latent_dim: 128
86
+ output_dim: 8
87
+ confidence: true
88
+ experiment:
89
+ name: maplocanet_602_hub_FPN_Resnet50_norelu
90
+ gpus: 2
91
+ seed: 0
92
+ training:
93
+ lr: 0.0001
94
+ lr_scheduler: null
95
+ finetune_from_checkpoint: null
96
+ trainer:
97
+ val_check_interval: 1000
98
+ log_every_n_steps: 100
99
+ # limit_val_batches: 1000
100
+ max_steps: 200000
101
+ devices: ${experiment.gpus}
102
+ checkpointing:
103
+ monitor: "val/xy_recall_1m"
104
+ save_top_k: 10
105
+ mode: min
106
+
107
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_FPN_Mobileone.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: FPN
50
+ backbone:
51
+ encoder: mobileone_s3
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ num_classes:
68
+ all: 50
69
+ # ways: 10
70
+ # nodes: 33
71
+ backbone:
72
+ encoder: vgg19
73
+ pretrained: false
74
+ output_scales:
75
+ - 0
76
+ num_downsample: 3
77
+ decoder:
78
+ - 128
79
+ - 64
80
+ - 64
81
+ padding: replicate
82
+ unary_prior: false
83
+ bev_net:
84
+ num_blocks: 4
85
+ latent_dim: 128
86
+ output_dim: 8
87
+ confidence: true
88
+ experiment:
89
+ name: maplocnetsinglhub_FPN_mobileone_s3
90
+ gpus: 2
91
+ seed: 0
92
+ training:
93
+ lr: 0.0001
94
+ lr_scheduler: null
95
+ finetune_from_checkpoint: null
96
+ trainer:
97
+ val_check_interval: 1000
98
+ log_every_n_steps: 100
99
+ # limit_val_batches: 1000
100
+ max_steps: 300000
101
+ devices: ${experiment.gpus}
102
+ checkpointing:
103
+ monitor: "val/xy_recall_1m"
104
+ save_top_k: 10
105
+ mode: max
106
+
107
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/maplocnetsinglhub_PSP.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ root: '/root/autodl-fs/DATASET/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ test_citys:
15
+ - SanFrancisco
16
+ image_size: 256
17
+ train:
18
+ batch_size: 12
19
+ num_workers: 4
20
+ val:
21
+ batch_size: ${..train.batch_size}
22
+ num_workers: ${.batch_size}
23
+ num_classes:
24
+ areas: 7
25
+ ways: 10
26
+ nodes: 33
27
+ pixel_per_meter: 1
28
+ crop_size_meters: 64
29
+ max_init_error: 48
30
+ add_map_mask: true
31
+ resize_image: 512
32
+ pad_to_square: true
33
+ rectify_pitch: true
34
+ augmentation:
35
+ rot90: true
36
+ flip: true
37
+ image:
38
+ apply: True
39
+ brightness: 0.5
40
+ contrast: 0.4
41
+ saturation: 0.4
42
+ hue": 0.5/3.14
43
+ model:
44
+ image_size: ${data.image_size}
45
+ latent_dim: 128
46
+ val_citys: ${data.val_citys}
47
+ image_encoder:
48
+ name: feature_extractor_v4
49
+ architecture: PSP
50
+ backbone:
51
+ encoder: resnet50
52
+ # pretrained: true
53
+ output_dim: 8
54
+ # upsampling: 2
55
+ # num_downsample: null
56
+ # remove_stride_from_first_conv: false
57
+ name: maplocnet
58
+ matching_dim: 8
59
+ z_max: 32
60
+ x_max: 32
61
+ pixel_per_meter: 1
62
+ num_scale_bins: 33
63
+ num_rotations: 64
64
+ map_encoder:
65
+ embedding_dim: 48
66
+ output_dim: 8
67
+ num_classes:
68
+ all: 50
69
+ # ways: 10
70
+ # nodes: 33
71
+ backbone:
72
+ encoder: vgg19
73
+ pretrained: false
74
+ output_scales:
75
+ - 0
76
+ num_downsample: 3
77
+ decoder:
78
+ - 128
79
+ - 64
80
+ - 64
81
+ padding: replicate
82
+ unary_prior: false
83
+ bev_net:
84
+ num_blocks: 4
85
+ latent_dim: 128
86
+ output_dim: 8
87
+ confidence: true
88
+ experiment:
89
+ name: maplocanet_602_hub_PSP
90
+ gpus: 2
91
+ seed: 0
92
+ training:
93
+ lr: 0.0001
94
+ lr_scheduler: null
95
+ finetune_from_checkpoint: null
96
+ trainer:
97
+ val_check_interval: 1000
98
+ log_every_n_steps: 100
99
+ # limit_val_batches: 1000
100
+ max_steps: 300000
101
+ devices: ${experiment.gpus}
102
+ checkpointing:
103
+ monitor: "val/xy_recall_1m"
104
+ save_top_k: 5
105
+ mode: max
106
+
107
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
conf/orienternet.yaml ADDED
@@ -0,0 +1,103 @@
1
+ data:
2
+ root: '/home/ubuntu/media/MapLocNetDataset/UAV/'
3
+ train_citys:
4
+ - Paris
5
+ - Berlin
6
+ - London
7
+ - Tokyo
8
+ - NewYork
9
+ val_citys:
10
+ # - Taipei
11
+ # - LosAngeles
12
+ # - Singapore
13
+ - SanFrancisco
14
+ image_size: 256
15
+ train:
16
+ batch_size: 12
17
+ num_workers: 4
18
+ val:
19
+ batch_size: ${..train.batch_size}
20
+ num_workers: ${.batch_size}
21
+ num_classes:
22
+ areas: 7
23
+ ways: 10
24
+ nodes: 33
25
+ pixel_per_meter: 1
26
+ crop_size_meters: 64
27
+ max_init_error: 48
28
+ add_map_mask: true
29
+ resize_image: 512
30
+ pad_to_square: true
31
+ rectify_pitch: true
32
+ augmentation:
33
+ rot90: true
34
+ # flip: true
35
+ image:
36
+ apply: true
37
+ brightness: 0.5
38
+ contrast: 0.4
39
+ saturation: 0.4
40
+ hue": 0.5/3.14
41
+ model:
42
+ image_size: ${data.image_size}
43
+ latent_dim: 128
44
+ val_citys: ${data.val_citys}
45
+ image_encoder:
46
+ name: feature_extractor_v2
47
+ backbone:
48
+ encoder: resnet101
49
+ pretrained: true
50
+ output_dim: 8
51
+ num_downsample: null
52
+ remove_stride_from_first_conv: false
53
+ name: orienternet
54
+ matching_dim: 8
55
+ z_max: 32
56
+ x_max: 32
57
+ pixel_per_meter: 1
58
+ num_scale_bins: 33
59
+ num_rotations: 64
60
+ map_encoder:
61
+ embedding_dim: 16
62
+ output_dim: 8
63
+ num_classes:
64
+ areas: 7
65
+ ways: 10
66
+ nodes: 33
67
+ backbone:
68
+ encoder: vgg19
69
+ pretrained: false
70
+ output_scales:
71
+ - 0
72
+ num_downsample: 3
73
+ decoder:
74
+ - 128
75
+ - 64
76
+ - 64
77
+ padding: replicate
78
+ unary_prior: false
79
+ bev_net:
80
+ num_blocks: 4
81
+ latent_dim: 128
82
+ output_dim: 8
83
+ confidence: true
84
+ experiment:
85
+ name: OrienterNet_my_multi_city_debug_code_0815_2_monitor_metric
86
+ gpus: 4
87
+ seed: 0
88
+ training:
89
+ lr: 0.0001
90
+ lr_scheduler: null
91
+ finetune_from_checkpoint: null
92
+ trainer:
93
+ val_check_interval: 1000
94
+ log_every_n_steps: 100
95
+ # limit_val_batches: 1000
96
+ max_steps: 200000
97
+ devices: ${experiment.gpus}
98
+ checkpointing:
99
+ monitor: "loss/total/val"
100
+ save_top_k: 10
101
+ mode: min
102
+
103
+ # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
dataset/UAV/dataset.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import os
4
+ import cv2
5
+ # @Time : 2023-02-13 22:56
6
+ # @Author : Wang Zhen
7
+ # @Email : [email protected]
8
+ # @File : SatelliteTool.py
9
+ # @Project : TGRS_seqmatch_2023_1
10
+ import numpy as np
11
+ import random
12
+ from utils.geo import BoundaryBox, Projection
13
+ from osm.tiling import TileManager,MapTileManager
14
+ from pathlib import Path
15
+ from torchvision import transforms
16
+ from torch.utils.data import DataLoader
17
+
18
+ class UavMapPair(Dataset):
19
+ def __init__(
20
+ self,
21
+ root: Path,
22
+ city:str,
23
+ training:bool,
24
+ transform
25
+ ):
26
+ super().__init__()
27
+
28
+ # self.root = root
29
+
30
+ # city = 'Manhattan'
31
+ # root = '/root/DATASET/CrossModel/'
32
+ # root=Path(root)
33
+ self.uav_image_path = root/city/'uav'
34
+ self.map_path = root/city/'map'
35
+ self.map_vis = root / city / 'map_vis'
36
+ info_path = root / city / 'info.csv'
37
+
38
+ self.info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)
39
+
40
+ self.transform=transform
41
+ self.training=training
42
+
43
+ def random_center_crop(self,image):
44
+ height, width = image.shape[:2]
45
+
46
+ # Randomly generate the crop size
47
+ crop_size = random.randint(min(height, width) // 2, min(height, width))
48
+
49
+ # Compute the starting coordinates of the crop
50
+ start_x = (width - crop_size) // 2
51
+ start_y = (height - crop_size) // 2
52
+
53
+ # Perform the crop
54
+ cropped_image = image[start_y:start_y + crop_size, start_x:start_x + crop_size]
55
+
56
+ return cropped_image
57
+ def __getitem__(self, index: int):
58
+ id, uav_name, map_name, \
59
+ uav_long, uav_lat, \
60
+ map_long, map_lat, \
61
+ tile_size_meters, pixel_per_meter, \
62
+ u, v, yaw,dis=self.info[index]
63
+
64
+
65
+ uav_image=cv2.imread(str(self.uav_image_path/uav_name))
66
+ if self.training:
67
+ uav_image =self.random_center_crop(uav_image)
68
+ uav_image=cv2.cvtColor(uav_image,cv2.COLOR_BGR2RGB)
69
+ if self.transform:
70
+ uav_image=self.transform(uav_image)
71
+ map=np.load(str(self.map_path/map_name))
72
+
73
+ return {
74
+ 'map':torch.from_numpy(np.ascontiguousarray(map)).long(),
75
+ 'image':torch.tensor(uav_image),
76
+ 'roll_pitch_yaw':torch.tensor((0, 0, float(yaw))).float(),
77
+ 'pixels_per_meter':torch.tensor(float(pixel_per_meter)).float(),
78
+ "uv":torch.tensor([float(u), float(v)]).float(),
79
+ }
80
+ def __len__(self):
81
+ return len(self.info)
82
+ if __name__ == '__main__':
83
+
84
+ root=Path('/root/DATASET/OrienterNet/UavMap/')
85
+ city='NewYork'
86
+
87
+ transform = transforms.Compose([
88
+ transforms.ToTensor(),
89
+ transforms.Resize(256),
90
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
91
+ ])
92
+
93
+ dataset=UavMapPair(
94
+ root=root,
95
+ city=city,
96
+ training=False,
+ transform=transform
97
+ )
98
+ datasetloder = DataLoader(dataset, batch_size=3)
99
+ for batch, i in enumerate(datasetloder):
100
+ pass
101
+ # Convert the PyTorch tensor to a PIL image
102
+ # pil_image = Image.fromarray(i['uav_image'][0].permute(1, 2, 0).byte().numpy())
103
+
104
+ # Display the image
105
+ # Convert the PyTorch tensor to a NumPy array
106
+ # numpy_array = i['uav_image'][0].numpy()
107
+ #
108
+ # # Display the image
109
+ # plt.imshow(numpy_array.transpose(1, 2, 0))
110
+ # plt.axis('off')
111
+ # plt.show()
112
+ #
113
+ # map_viz, label = Colormap.apply(i['map'][0])
114
+ # map_viz = map_viz * 255
115
+ # map_viz = map_viz.astype(np.uint8)
116
+ # plot_images([map_viz], titles=["OpenStreetMap raster"])
dataset/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # from .UAV.dataset import UavMapPair
2
+ from .dataset import UavMapDatasetModule
3
+
4
+ # modules = {"UAV": UavMapPair}
dataset/dataset.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
+ # from logger import logger
7
+ import numpy as np
8
+ # import torch
9
+ # import torch.utils.data as torchdata
10
+ # import torchvision.transforms as tvf
11
+ from omegaconf import DictConfig, OmegaConf
12
+ import pytorch_lightning as pl
13
+ from dataset.UAV.dataset import UavMapPair
14
+ # from torch.utils.data import Dataset, DataLoader
15
+ # from torchvision import transforms
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from torch.utils.data import Dataset, DataLoader, random_split
18
+ import torchvision.transforms as tvf
19
+
20
+ # Custom data module class inheriting from pl.LightningDataModule
21
+ class UavMapDatasetModule(pl.LightningDataModule):
22
+
23
+
24
+ def __init__(self, cfg: Dict[str, Any]):
25
+ super().__init__()
26
+
27
+ # default_cfg = OmegaConf.create(self.default_cfg)
28
+ # OmegaConf.set_struct(default_cfg, True) # cannot add new keys
29
+ # self.cfg = OmegaConf.merge(default_cfg, cfg)
30
+ self.cfg=cfg
31
+ # self.transform = tvf.Compose([
32
+ # tvf.ToTensor(),
33
+ # tvf.Resize(self.cfg.image_size),
34
+ # tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
35
+ # ])
36
+
37
+ tfs = []
38
+ tfs.append(tvf.ToTensor())
39
+ tfs.append(tvf.Resize(self.cfg.image_size))
40
+ self.val_tfs = tvf.Compose(list(tfs))  # copy the list so the ColorJitter appended below does not leak into the val transforms
41
+
42
+ # transforms.Resize(self.cfg.image_size),
43
+ if cfg.augmentation.image.apply:
44
+ args = OmegaConf.masked_copy(
45
+ cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
46
+ )
47
+ tfs.append(tvf.ColorJitter(**args))
48
+ self.train_tfs = tvf.Compose(tfs)
49
+
50
+ # self.train_tfs=self.transform
51
+ # self.val_tfs = self.transform
52
+ self.init()
53
+ def init(self):
54
+ self.train_dataset = ConcatDataset([
55
+ UavMapPair(root=Path(self.cfg.root),city=city,training=False,transform=self.train_tfs)
56
+ for city in self.cfg.train_citys
57
+ ])
58
+
59
+ self.val_dataset = ConcatDataset([
60
+ UavMapPair(root=Path(self.cfg.root),city=city,training=False,transform=self.val_tfs)
61
+ for city in self.cfg.val_citys
62
+ ])
63
+ self.test_dataset = ConcatDataset([
64
+ UavMapPair(root=Path(self.cfg.root),city=city,training=False,transform=self.val_tfs)
65
+ for city in self.cfg.test_citys
66
+ ])
67
+
68
+ # self.val_datasets = {
69
+ # city:UavMapPair(root=Path(self.cfg.root),city=city,transform=self.val_tfs)
70
+ # for city in self.cfg.val_citys
71
+ # }
72
+ # logger.info("train data len:{},val data len:{}".format(len(self.train_dataset),len(self.val_dataset)))
73
+ # # Define the split ratio
74
+ # train_ratio = 0.8 # training set ratio
75
+ # # Compute the number of samples in each split
76
+ # train_size = int(len(self.dataset) * train_ratio)
77
+ # val_size = len(self.dataset) - train_size
78
+ # self.train_dataset, self.val_dataset = random_split(self.dataset, [train_size, val_size])
79
+ def train_dataloader(self):
80
+ train_loader = DataLoader(self.train_dataset,
81
+ batch_size=self.cfg.train.batch_size,
82
+ num_workers=self.cfg.train.num_workers,
83
+ shuffle=True,pin_memory = True)
84
+ return train_loader
85
+
86
+ def val_dataloader(self):
87
+ val_loader = DataLoader(self.val_dataset,
88
+ batch_size=self.cfg.val.batch_size,
89
+ num_workers=self.cfg.val.num_workers,
90
+ shuffle=True,pin_memory = True)
91
+ #
92
+ # my_dict = {k: v for k, v in self.val_datasets}
93
+ # val_loaders={city: DataLoader(dataset,
94
+ # batch_size=self.cfg.val.batch_size,
95
+ # num_workers=self.cfg.val.num_workers,
96
+ # shuffle=False,pin_memory = True) for city, dataset in self.val_datasets.items()}
97
+ return val_loader
98
+ def test_dataloader(self):
99
+ val_loader = DataLoader(self.test_dataset,
100
+ batch_size=self.cfg.val.batch_size,
101
+ num_workers=self.cfg.val.num_workers,
102
+ shuffle=True,pin_memory = True)
103
+ #
104
+ # my_dict = {k: v for k, v in self.val_datasets}
105
+ # val_loaders={city: DataLoader(dataset,
106
+ # batch_size=self.cfg.val.batch_size,
107
+ # num_workers=self.cfg.val.num_workers,
108
+ # shuffle=False,pin_memory = True) for city, dataset in self.val_datasets.items()}
109
+ return val_loader
dataset/image.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Callable, Optional, Union, Sequence
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms.functional as tvf
8
+ import collections
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from utils.geometry import from_homogeneous, to_homogeneous
12
+ from utils.wrappers import Camera
13
+
14
+
15
+ def rectify_image(
16
+ image: torch.Tensor,
17
+ cam: Camera,
18
+ roll: float,
19
+ pitch: Optional[float] = None,
20
+ valid: Optional[torch.Tensor] = None,
21
+ ):
22
+ *_, h, w = image.shape
23
+ grid = torch.meshgrid(
24
+ [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
25
+ indexing="xy",
26
+ )
27
+ grid = torch.stack(grid, -1).to(image.dtype)
28
+
29
+ if pitch is not None:
30
+ args = ("ZX", (roll, pitch))
31
+ else:
32
+ args = ("Z", roll)
33
+ R = Rotation.from_euler(*args, degrees=True).as_matrix()
34
+ R = torch.from_numpy(R).to(image)
35
+
36
+ grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
37
+ grid_rect = cam.denormalize(from_homogeneous(grid_rect))
38
+ grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
39
+ rectified = torch.nn.functional.grid_sample(
40
+ image[None],
41
+ grid_norm[None],
42
+ align_corners=False,
43
+ mode="bilinear",
44
+ ).squeeze(0)
45
+ if valid is None:
46
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
47
+ else:
48
+ valid = (
49
+ torch.nn.functional.grid_sample(
50
+ valid[None, None].float(),
51
+ grid_norm[None],
52
+ align_corners=False,
53
+ mode="nearest",
54
+ )[0, 0]
55
+ > 0
56
+ )
57
+ return rectified, valid
58
+
59
+
60
+ def resize_image(
61
+ image: torch.Tensor,
62
+ size: Union[int, Sequence, np.ndarray],
63
+ fn: Optional[Callable] = None,
64
+ camera: Optional[Camera] = None,
65
+ valid: np.ndarray = None,
66
+ ):
67
+ """Resize an image to a fixed size, or according to max or min edge."""
68
+ *_, h, w = image.shape
69
+ if fn is not None:
70
+ assert isinstance(size, int)
71
+ scale = size / fn(h, w)
72
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
73
+ scale = (scale, scale)
74
+ else:
75
+ if isinstance(size, (collections.abc.Sequence, np.ndarray)):
76
+ w_new, h_new = size
77
+ elif isinstance(size, int):
78
+ w_new = h_new = size
79
+ else:
80
+ raise ValueError(f"Incorrect new size: {size}")
81
+ scale = (w_new / w, h_new / h)
82
+ if (w, h) != (w_new, h_new):
83
+ mode = tvf.InterpolationMode.BILINEAR
84
+ image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
85
+ image.clip_(0, 1)
86
+ if camera is not None:
87
+ camera = camera.scale(scale)
88
+ if valid is not None:
89
+ valid = tvf.resize(
90
+ valid.unsqueeze(0),
91
+ (h_new, w_new),
92
+ interpolation=tvf.InterpolationMode.NEAREST,
93
+ ).squeeze(0)
94
+ ret = [image, scale]
95
+ if camera is not None:
96
+ ret.append(camera)
97
+ if valid is not None:
98
+ ret.append(valid)
99
+ return ret
100
+
101
+
102
+ def pad_image(
103
+ image: torch.Tensor,
104
+ size: Union[int, Sequence, np.ndarray],
105
+ camera: Optional[Camera] = None,
106
+ valid: torch.Tensor = None,
107
+ crop_and_center: bool = False,
108
+ ):
109
+ if isinstance(size, int):
110
+ w_new = h_new = size
111
+ elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
112
+ w_new, h_new = size
113
+ else:
114
+ raise ValueError(f"Incorrect new size: {size}")
115
+ *c, h, w = image.shape
116
+ if crop_and_center:
117
+ diff = np.array([w - w_new, h - h_new])
118
+ left, top = left_top = np.round(diff / 2).astype(int)
119
+ right, bottom = diff - left_top
120
+ else:
121
+ assert h <= h_new
122
+ assert w <= w_new
123
+ top = bottom = left = right = 0
124
+ slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
125
+ slice_in = np.s_[
126
+ ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
127
+ ]
128
+ if (w, h) == (w_new, h_new):
129
+ out = image
130
+ else:
131
+ out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
132
+ out[slice_out] = image[slice_in]
133
+ if camera is not None:
134
+ camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
135
+ out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
136
+ out_valid[slice_out] = True if valid is None else valid[slice_in]
137
+ if camera is not None:
138
+ return out, out_valid, camera
139
+ else:
140
+ return out, out_valid
dataset/torch.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import collections
4
+ import os
5
+
6
+ import torch
7
+ from torch.utils.data import get_worker_info
8
+ from torch.utils.data._utils.collate import (
9
+ default_collate_err_msg_format,
10
+ np_str_obj_array_pattern,
11
+ )
12
+ from lightning_fabric.utilities.seed import pl_worker_init_function
13
+ from lightning_utilities.core.apply_func import apply_to_collection
14
+ from lightning_fabric.utilities.apply_func import move_data_to_device
15
+
16
+
17
+ def collate(batch):
18
+ """Difference with PyTorch default_collate: it can stack other tensor-like objects.
19
+ Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
20
+ https://github.com/cvg/pixloc
21
+ Released under the Apache License 2.0
22
+ """
23
+ if not isinstance(batch, list): # no batching
24
+ return batch
25
+ elem = batch[0]
26
+ elem_type = type(elem)
27
+ if isinstance(elem, torch.Tensor):
28
+ out = None
29
+ if torch.utils.data.get_worker_info() is not None:
30
+ # If we're in a background process, concatenate directly into a
31
+ # shared memory tensor to avoid an extra copy
32
+ numel = sum(x.numel() for x in batch)
33
+ storage = elem.storage()._new_shared(numel, device=elem.device)
34
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
35
+ return torch.stack(batch, 0, out=out)
36
+ elif (
37
+ elem_type.__module__ == "numpy"
38
+ and elem_type.__name__ != "str_"
39
+ and elem_type.__name__ != "string_"
40
+ ):
41
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
42
+ # array of string classes and object
43
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
44
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
45
+
46
+ return collate([torch.as_tensor(b) for b in batch])
47
+ elif elem.shape == (): # scalars
48
+ return torch.as_tensor(batch)
49
+ elif isinstance(elem, float):
50
+ return torch.tensor(batch, dtype=torch.float64)
51
+ elif isinstance(elem, int):
52
+ return torch.tensor(batch)
53
+ elif isinstance(elem, (str, bytes)):
54
+ return batch
55
+ elif isinstance(elem, collections.abc.Mapping):
56
+ return {key: collate([d[key] for d in batch]) for key in elem}
57
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
58
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
59
+ elif isinstance(elem, collections.abc.Sequence):
60
+ # check to make sure that the elements in batch have consistent size
61
+ it = iter(batch)
62
+ elem_size = len(next(it))
63
+ if not all(len(elem) == elem_size for elem in it):
64
+ raise RuntimeError("each element in list of batch should be of equal size")
65
+ transposed = zip(*batch)
66
+ return [collate(samples) for samples in transposed]
67
+ else:
68
+ # try to stack anyway in case the object implements stacking.
69
+ try:
70
+ return torch.stack(batch, 0)
71
+ except TypeError as e:
72
+ if "expected Tensor as element" in str(e):
73
+ return batch
74
+ else:
75
+ raise e
76
+
77
+
78
+ def set_num_threads(nt):
79
+ """Force numpy and other libraries to use a limited number of threads."""
80
+ try:
81
+ import mkl
82
+ except ImportError:
83
+ pass
84
+ else:
85
+ mkl.set_num_threads(nt)
86
+ torch.set_num_threads(1)
87
+ os.environ["IPC_ENABLE"] = "1"
88
+ for o in [
89
+ "OPENBLAS_NUM_THREADS",
90
+ "NUMEXPR_NUM_THREADS",
91
+ "OMP_NUM_THREADS",
92
+ "MKL_NUM_THREADS",
93
+ ]:
94
+ os.environ[o] = str(nt)
95
+
96
+
97
+ def worker_init_fn(i):
98
+ info = get_worker_info()
99
+ pl_worker_init_function(info.id)
100
+ num_threads = info.dataset.cfg.get("num_threads")
101
+ if num_threads is not None:
102
+ set_num_threads(num_threads)
103
+
104
+
105
+ def unbatch_to_device(data, device="cpu"):
106
+ data = move_data_to_device(data, device)
107
+ data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
108
+ data = apply_to_collection(
109
+ data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
110
+ )
111
+ return data
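As a quick illustration (not part of the commit), `collate` and `worker_init_fn` above are meant to be passed straight to a standard PyTorch `DataLoader`. The toy dataset below is hypothetical; it only carries a `cfg` mapping because `worker_init_fn` reads `cfg["num_threads"]` from the dataset inside each worker.
```python
# Illustrative sketch, not part of the commit: wiring `collate` and
# `worker_init_fn` into a DataLoader. ToyDataset is a placeholder.
import torch
from torch.utils.data import DataLoader, Dataset

from dataset.torch import collate, worker_init_fn


class ToyDataset(Dataset):
    cfg = {"num_threads": 1}  # read by worker_init_fn inside each worker process

    def __len__(self):
        return 16

    def __getitem__(self, i):
        # nested mappings are collated recursively; strings are kept as lists
        return {"image": torch.rand(3, 32, 32), "name": f"sample_{i}"}


loader = DataLoader(
    ToyDataset(),
    batch_size=4,
    num_workers=2,
    collate_fn=collate,
    worker_init_fn=worker_init_fn,
)
batch = next(iter(loader))  # batch["image"]: (4, 3, 32, 32); batch["name"]: list of 4 strings
```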
evaluation/kitti.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple
6
+
7
+ from omegaconf import OmegaConf, DictConfig
8
+
9
+ from .. import logger
10
+ from ..data import KittiDataModule
11
+ from .run import evaluate
12
+
13
+
14
+ default_cfg_single = OmegaConf.create({})
15
+ # For the sequential evaluation, we need to center the map around the GT location,
16
+ # since random offsets would accumulate and leave only the GT location with a valid mask.
17
+ # This should not have much impact on the results.
18
+ default_cfg_sequential = OmegaConf.create(
19
+ {
20
+ "data": {
21
+ "mask_radius": KittiDataModule.default_cfg["max_init_error"],
22
+ "prior_range_rotation": KittiDataModule.default_cfg[
23
+ "max_init_error_rotation"
24
+ ]
25
+ + 1,
26
+ "max_init_error": 0,
27
+ "max_init_error_rotation": 0,
28
+ },
29
+ "chunking": {
30
+ "max_length": 100, # about 10s?
31
+ },
32
+ }
33
+ )
34
+
35
+
36
+ def run(
37
+ split: str,
38
+ experiment: str,
39
+ cfg: Optional[DictConfig] = None,
40
+ sequential: bool = False,
41
+ thresholds: Tuple[int] = (1, 3, 5),
42
+ **kwargs,
43
+ ):
44
+ cfg = cfg or {}
45
+ if isinstance(cfg, dict):
46
+ cfg = OmegaConf.create(cfg)
47
+ default = default_cfg_sequential if sequential else default_cfg_single
48
+ cfg = OmegaConf.merge(default, cfg)
49
+ dataset = KittiDataModule(cfg.get("data", {}))
50
+
51
+ metrics = evaluate(
52
+ experiment,
53
+ cfg,
54
+ dataset,
55
+ split=split,
56
+ sequential=sequential,
57
+ viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
58
+ **kwargs,
59
+ )
60
+
61
+ keys = ["directional_error", "yaw_max_error"]
62
+ if sequential:
63
+ keys += ["directional_seq_error", "yaw_seq_error"]
64
+ for k in keys:
65
+ rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
66
+ logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
67
+ return metrics
68
+
69
+
70
+ if __name__ == "__main__":
71
+ parser = argparse.ArgumentParser()
72
+ parser.add_argument("--experiment", type=str, required=True)
73
+ parser.add_argument(
74
+ "--split", type=str, default="test", choices=["test", "val", "train"]
75
+ )
76
+ parser.add_argument("--sequential", action="store_true")
77
+ parser.add_argument("--output_dir", type=Path)
78
+ parser.add_argument("--num", type=int)
79
+ parser.add_argument("dotlist", nargs="*")
80
+ args = parser.parse_args()
81
+ cfg = OmegaConf.from_cli(args.dotlist)
82
+ run(
83
+ args.split,
84
+ args.experiment,
85
+ cfg,
86
+ args.sequential,
87
+ output_dir=args.output_dir,
88
+ num=args.num,
89
+ )
evaluation/mapillary.py ADDED
File without changes
evaluation/run.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import functools
4
+ from itertools import islice
5
+ from typing import Callable, Dict, Optional, Tuple
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from torchmetrics import MetricCollection
12
+ from pytorch_lightning import seed_everything
13
+ from tqdm import tqdm
14
+
15
+ from logger import logger, EXPERIMENTS_PATH
16
+ from dataset.torch import collate, unbatch_to_device
17
+ from models.voting import argmax_xyr, fuse_gps
18
+ from models.metrics import AngleError, LateralLongitudinalError, Location2DError
19
+ from models.sequential import GPSAligner, RigidAligner
20
+ from module import GenericModule
21
+ from utils.io import download_file, DATA_URL
22
+ from evaluation.viz import plot_example_single, plot_example_sequential
23
+ from evaluation.utils import write_dump
24
+
25
+
26
+ pretrained_models = dict(
27
+ OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
28
+ )
29
+
30
+
31
+ def resolve_checkpoint_path(experiment_or_path: str) -> Path:
32
+ path = Path(experiment_or_path)
33
+ if not path.exists():
34
+ # provided name of experiment
35
+ path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
36
+ if not path.exists():
37
+ if experiment_or_path in set(p for p, _ in pretrained_models.values()):
38
+ download_file(f"{DATA_URL}/{experiment_or_path}", path)
39
+ else:
40
+ raise FileNotFoundError(path)
41
+ if path.is_file():
42
+ return path
43
+ # provided only the experiment name
44
+ maybe_path = path / "last-step-v1.ckpt"
45
+ if not maybe_path.exists():
46
+ maybe_path = path / "last.ckpt"
47
+ if not maybe_path.exists():
48
+ raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
49
+ return maybe_path
50
+
51
+
52
+ @torch.no_grad()
53
+ def evaluate_single_image(
54
+ dataloader: torch.utils.data.DataLoader,
55
+ model: GenericModule,
56
+ num: Optional[int] = None,
57
+ callback: Optional[Callable] = None,
58
+ progress: bool = True,
59
+ mask_index: Optional[Tuple[int]] = None,
60
+ has_gps: bool = False,
61
+ ):
62
+ ppm = model.model.conf.pixel_per_meter
63
+ metrics = MetricCollection(model.model.metrics())
64
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
65
+ if has_gps:
66
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
67
+ metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
68
+ metrics["yaw_fused_error"] = AngleError("yaw_fused")
69
+ metrics = metrics.to(model.device)
70
+
71
+ for i, batch_ in enumerate(
72
+ islice(tqdm(dataloader, total=num, disable=not progress), num)
73
+ ):
74
+ batch = model.transfer_batch_to_device(batch_, model.device, i)
75
+ # Ablation: mask semantic classes
76
+ if mask_index is not None:
77
+ mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
78
+ batch["map"][0, mask_index[0]][mask] = 0
79
+ pred = model(batch)
80
+
81
+ if has_gps:
82
+ (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
83
+ pred["log_probs_fused"] = fuse_gps(
84
+ pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
85
+ )
86
+ uvt_fused = argmax_xyr(pred["log_probs_fused"])
87
+ pred["uv_fused"] = uvt_fused[..., :2]
88
+ pred["yaw_fused"] = uvt_fused[..., -1]
89
+ del uv_gps, uvt_fused
90
+
91
+ results = metrics(pred, batch)
92
+ if callback is not None:
93
+ callback(
94
+ i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
95
+ )
96
+ del batch_, batch, pred, results
97
+
98
+ return metrics.cpu()
99
+
100
+
101
+ @torch.no_grad()
102
+ def evaluate_sequential(
103
+ dataset: torch.utils.data.Dataset,
104
+ chunk2idx: Dict,
105
+ model: GenericModule,
106
+ num: Optional[int] = None,
107
+ shuffle: bool = False,
108
+ callback: Optional[Callable] = None,
109
+ progress: bool = True,
110
+ num_rotations: int = 512,
111
+ mask_index: Optional[Tuple[int]] = None,
112
+ has_gps: bool = True,
113
+ ):
114
+ chunk_keys = list(chunk2idx)
115
+ if shuffle:
116
+ chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
117
+ if num is not None:
118
+ chunk_keys = chunk_keys[:num]
119
+ lengths = [len(chunk2idx[k]) for k in chunk_keys]
120
+ logger.info(
121
+ "Min/max/med lengths: %d/%d/%d, total number of images: %d",
122
+ min(lengths),
123
+ np.median(lengths),
124
+ max(lengths),
125
+ sum(lengths),
126
+ )
127
+ viz = callback is not None
128
+
129
+ metrics = MetricCollection(model.model.metrics())
130
+ ppm = model.model.conf.pixel_per_meter
131
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
132
+ metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
133
+ metrics["yaw_seq_error"] = AngleError("yaw_seq")
134
+ metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
135
+ if has_gps:
136
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
137
+ metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
138
+ metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
139
+ metrics = metrics.to(model.device)
140
+
141
+ keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
142
+ if has_gps:
143
+ keys_save.append("uv_gps")
144
+ if viz:
145
+ keys_save.append("log_probs")
146
+
147
+ for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
148
+ indices = chunk2idx[key]
149
+ aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
150
+ if has_gps:
151
+ aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
152
+ batches = []
153
+ preds = []
154
+ for i in indices:
155
+ data = dataset[i]
156
+ data = model.transfer_batch_to_device(data, model.device, 0)
157
+ pred = model(collate([data]))
158
+
159
+ canvas = data["canvas"]
160
+ data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
161
+ data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
162
+ aligner.update(pred["log_probs"][0], canvas, xy, yaw)
163
+
164
+ if has_gps:
165
+ (uv_gps) = pred["uv_gps"] = data["uv_gps"][None]
166
+ xy_gps = canvas.to_xy(uv_gps.double())
167
+ aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)
168
+
169
+ if not viz:
170
+ data.pop("image")
171
+ data.pop("map")
172
+ batches.append(data)
173
+ preds.append({k: pred[k][0] for k in keys_save})
174
+ del pred
175
+
176
+ xy_gt = torch.stack([b["xy_geo"] for b in batches])
177
+ yaw_gt = torch.stack([b["yaw"] for b in batches])
178
+ aligner.compute()
179
+ xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
180
+ if has_gps:
181
+ aligner_gps.compute()
182
+ xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
183
+ results = []
184
+ for i in range(len(indices)):
185
+ preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
186
+ preds[i]["yaw_seq"] = yaw_seq[i].float()
187
+ if has_gps:
188
+ preds[i]["uv_gps_seq"] = (
189
+ batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
190
+ )
191
+ preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
192
+ results.append(metrics(preds[i], batches[i]))
193
+ if viz:
194
+ callback(chunk_index, model, batches, preds, results, aligner)
195
+ del aligner, preds, batches, results
196
+ return metrics.cpu()
197
+
198
+
199
+ def evaluate(
200
+ experiment: str,
201
+ cfg: DictConfig,
202
+ dataset,
203
+ split: str,
204
+ sequential: bool = False,
205
+ output_dir: Optional[Path] = None,
206
+ callback: Optional[Callable] = None,
207
+ num_workers: int = 1,
208
+ viz_kwargs=None,
209
+ **kwargs,
210
+ ):
211
+ if experiment in pretrained_models:
212
+ experiment, cfg_override = pretrained_models[experiment]
213
+ cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)
214
+
215
+ logger.info("Evaluating model %s with config %s", experiment, cfg)
216
+ checkpoint_path = resolve_checkpoint_path(experiment)
217
+ model = GenericModule.load_from_checkpoint(
218
+ checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
219
+ )
220
+ model = model.eval()
221
+ if torch.cuda.is_available():
222
+ model = model.cuda()
223
+
224
+ dataset.prepare_data()
225
+ dataset.setup()
226
+
227
+ if output_dir is not None:
228
+ output_dir.mkdir(exist_ok=True, parents=True)
229
+ if callback is None:
230
+ if sequential:
231
+ callback = plot_example_sequential
232
+ else:
233
+ callback = plot_example_single
234
+ callback = functools.partial(
235
+ callback, out_dir=output_dir, **(viz_kwargs or {})
236
+ )
237
+ kwargs = {**kwargs, "callback": callback}
238
+
239
+ seed_everything(dataset.cfg.seed)
240
+ if sequential:
241
+ dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
242
+ metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
243
+ else:
244
+ loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
245
+ metrics = evaluate_single_image(loader, model, **kwargs)
246
+
247
+ results = metrics.compute()
248
+ logger.info("All results: %s", results)
249
+ if output_dir is not None:
250
+ write_dump(output_dir, experiment, cfg, results, metrics)
251
+ logger.info("Outputs have been written to %s.", output_dir)
252
+ return metrics
evaluation/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+
6
+ from utils.io import write_json
7
+
8
+
9
+ def compute_recall(errors):
10
+ num_elements = len(errors)
11
+ sort_idx = np.argsort(errors)
12
+ errors = np.array(errors.copy())[sort_idx]
13
+ recall = (np.arange(num_elements) + 1) / num_elements
14
+ recall = np.r_[0, recall]
15
+ errors = np.r_[0, errors]
16
+ return errors, recall
17
+
18
+
19
+ def compute_auc(errors, recall, thresholds):
20
+ aucs = []
21
+ for t in thresholds:
22
+ last_index = np.searchsorted(errors, t, side="right")
23
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
24
+ e = np.r_[errors[:last_index], t]
25
+ auc = np.trapz(r, x=e) / t
26
+ aucs.append(auc * 100)
27
+ return aucs
28
+
29
+
30
+ def write_dump(output_dir, experiment, cfg, results, metrics):
31
+ dump = {
32
+ "experiment": experiment,
33
+ "cfg": OmegaConf.to_container(cfg),
34
+ "results": results,
35
+ "errors": {},
36
+ }
37
+ for k, m in metrics.items():
38
+ if hasattr(m, "get_errors"):
39
+ dump["errors"][k] = m.get_errors().numpy()
40
+ write_json(output_dir / "log.json", dump)
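For reference, `compute_recall` and `compute_auc` above combine as follows on a synthetic list of position errors (the values are made up; the thresholds mirror the 1/3/5 m values used in evaluation/kitti.py).
```python
# Synthetic example of the recall/AUC helpers above; error values are illustrative only.
import numpy as np

from evaluation.utils import compute_recall, compute_auc

errors = np.array([0.4, 1.2, 2.5, 3.0, 7.5, 12.0])  # position errors in meters
errors_sorted, recall = compute_recall(errors)
aucs = compute_auc(errors_sorted, recall, thresholds=[1, 3, 5])
print(aucs)  # area under the recall curve (in %) up to 1 m, 3 m and 5 m
```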
evaluation/viz.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+
7
+ from utils.io import write_torch_image
8
+ from utils.viz_2d import plot_images, features_to_RGB, save_plot
9
+ from utils.viz_localization import (
10
+ likelihood_overlay,
11
+ plot_pose,
12
+ plot_dense_rotations,
13
+ add_circle_inset,
14
+ )
15
+ from osm.viz import Colormap, plot_nodes
16
+
17
+
18
+ def plot_example_single(
19
+ idx,
20
+ model,
21
+ pred,
22
+ data,
23
+ results,
24
+ plot_bev=True,
25
+ out_dir=None,
26
+ fig_for_paper=False,
27
+ show_gps=False,
28
+ show_fused=False,
29
+ show_dir_error=False,
30
+ show_masked_prob=False,
31
+ ):
32
+ scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
33
+ uv_gps = data.get("uv_gps")
34
+ yaw_gt = data["roll_pitch_yaw"][-1].numpy()
35
+ image = data["image"].permute(1, 2, 0)
36
+ if "valid" in data:
37
+ image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)
38
+
39
+ lp_uvt = lp_uv = pred["log_probs"]
40
+ if show_fused and "log_probs_fused" in pred:
41
+ lp_uvt = lp_uv = pred["log_probs_fused"]
42
+ elif not show_masked_prob and "scores_unmasked" in pred:
43
+ lp_uvt = lp_uv = pred["scores_unmasked"]
44
+ has_rotation = lp_uvt.ndim == 3
45
+ if has_rotation:
46
+ lp_uv = lp_uvt.max(-1).values
47
+ if lp_uv.min() > -np.inf:
48
+ lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
49
+ prob = lp_uv.exp()
50
+ uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
51
+ if show_fused and "uv_fused" in pred:
52
+ uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
53
+ feats_map = pred["map"]["map_features"][0]
54
+ (feats_map_rgb,) = features_to_RGB(feats_map.numpy())
55
+
56
+ text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
57
+ if has_rotation:
58
+ text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
59
+ if show_fused and "xy_fused_error" in results:
60
+ text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
61
+ text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
62
+ if show_dir_error and "directional_error" in results:
63
+ err_lat, err_lon = results["directional_error"]
64
+ text1 += rf", $\Delta$lateral/longitundinal={err_lat:.1f}m/{err_lon:.1f}m"
65
+ if "xy_gps_error" in results:
66
+ text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'
67
+
68
+ map_viz = Colormap.apply(rasters)
69
+ overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
70
+ plot_images(
71
+ [image, map_viz, overlay, feats_map_rgb],
72
+ titles=[text1, "map", "likelihood", "neural map"],
73
+ dpi=75,
74
+ cmaps="jet",
75
+ )
76
+ fig = plt.gcf()
77
+ axes = fig.axes
78
+ axes[1].images[0].set_interpolation("none")
79
+ axes[2].images[0].set_interpolation("none")
80
+ Colormap.add_colorbar()
81
+ plot_nodes(1, rasters[2])
82
+
83
+ if show_gps and uv_gps is not None:
84
+ plot_pose([1], uv_gps, c="blue")
85
+ plot_pose([1], uv_gt, yaw_gt, c="red")
86
+ plot_pose([1], uv_p, yaw_p, c="k")
87
+ plot_dense_rotations(2, lp_uvt.exp())
88
+ inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
89
+ axins = add_circle_inset(axes[2], inset_center)
90
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
91
+ axes[0].text(
92
+ 0.003,
93
+ 0.003,
94
+ f"{scene}/{name}",
95
+ transform=axes[0].transAxes,
96
+ fontsize=3,
97
+ va="bottom",
98
+ ha="left",
99
+ color="w",
100
+ )
101
+ plt.show()
102
+ if out_dir is not None:
103
+ name_ = name.replace("/", "_")
104
+ p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
105
+ save_plot(p.format("pred"))
106
+ plt.close()
107
+
108
+ if fig_for_paper:
109
+ # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
110
+ plot_images([map_viz])
111
+ plt.gca().images[0].set_interpolation("none")
112
+ plot_nodes(0, rasters[2])
113
+ plot_pose([0], uv_gt, yaw_gt, c="red")
114
+ plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
115
+ save_plot(p.format("map"))
116
+ plt.close()
117
+ plot_images([lp_uv], cmaps="jet")
118
+ plot_dense_rotations(0, lp_uvt.exp())
119
+ save_plot(p.format("loglikelihood"), dpi=100)
120
+ plt.close()
121
+ plot_images([overlay])
122
+ plt.gca().images[0].set_interpolation("none")
123
+ axins = add_circle_inset(plt.gca(), inset_center)
124
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
125
+ save_plot(p.format("likelihood"))
126
+ plt.close()
127
+ write_torch_image(
128
+ p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
129
+ )
130
+ write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())
131
+
132
+ if not plot_bev:
133
+ return
134
+
135
+ feats_q = pred["features_bev"]
136
+ mask_bev = pred["valid_bev"]
137
+ prior = None
138
+ if "log_prior" in pred["map"]:
139
+ prior = pred["map"]["log_prior"][0].sigmoid()
140
+ if "bev" in pred and "confidence" in pred["bev"]:
141
+ conf_q = pred["bev"]["confidence"]
142
+ else:
143
+ conf_q = torch.norm(feats_q, dim=0)
144
+ conf_q = conf_q.masked_fill(~mask_bev, np.nan)
145
+ (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
146
+ # feats_map_rgb, feats_q_rgb, = features_to_RGB(
147
+ # feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
148
+ norm_map = torch.norm(feats_map, dim=0)
149
+
150
+ plot_images(
151
+ [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
152
+ titles=["BEV confidence", "BEV features", "map norm"]
153
+ + ([] if prior is None else ["map prior"]),
154
+ dpi=50,
155
+ cmaps="jet",
156
+ )
157
+ plt.show()
158
+
159
+ if out_dir is not None:
160
+ save_plot(p.format("bev"))
161
+ plt.close()
162
+
163
+
164
+ def plot_example_sequential(
165
+ idx,
166
+ model,
167
+ pred,
168
+ data,
169
+ results,
170
+ plot_bev=True,
171
+ out_dir=None,
172
+ fig_for_paper=False,
173
+ show_gps=False,
174
+ show_fused=False,
175
+ show_dir_error=False,
176
+ show_masked_prob=False,
177
+ ):
178
+ return
feature_extractor_models/__init__.py ADDED
@@ -0,0 +1,82 @@
1
+
2
+ from . import encoders
3
+ from . import decoders
4
+
5
+
6
+ from .decoders.unet import Unet
7
+ from .decoders.unetplusplus import UnetPlusPlus
8
+ from .decoders.manet import MAnet
9
+ from .decoders.linknet import Linknet
10
+ from .decoders.fpn import FPN
11
+ from .decoders.lightfpn import LightFPN
12
+ from .decoders.pspnet import PSPNet
13
+ from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
14
+ from .decoders.pan import PAN
15
+ from .base.hub_mixin import from_pretrained
16
+
17
+ from .__version__ import __version__
18
+
19
+ # some private imports for create_model function
20
+ from typing import Optional as _Optional
21
+ import torch as _torch
22
+
23
+
24
+ def create_model(
25
+ arch: str,
26
+ encoder_name: str = "resnet34",
27
+ encoder_weights: _Optional[str] = "imagenet",
28
+ in_channels: int = 3,
29
+ classes: int = 1,
30
+ **kwargs,
31
+ ) -> _torch.nn.Module:
32
+ """Models entrypoint, allows to create any model architecture just with
33
+ parameters, without using its class
34
+ """
35
+
36
+ archs = [
37
+ Unet,
38
+ UnetPlusPlus,
39
+ MAnet,
40
+ Linknet,
41
+ FPN,
42
+ LightFPN,
43
+ PSPNet,
44
+ DeepLabV3,
45
+ DeepLabV3Plus,
46
+ PAN,
47
+ ]
48
+ archs_dict = {a.__name__.lower(): a for a in archs}
49
+ try:
50
+ model_class = archs_dict[arch.lower()]
51
+ except KeyError:
52
+ raise KeyError(
53
+ "Wrong architecture type `{}`. Available options are: {}".format(
54
+ arch, list(archs_dict.keys())
55
+ )
56
+ )
57
+ return model_class(
58
+ encoder_name=encoder_name,
59
+ encoder_weights=encoder_weights,
60
+ in_channels=in_channels,
61
+ classes=classes,
62
+ **kwargs,
63
+ )
64
+
65
+
66
+ __all__ = [
67
+ "encoders",
68
+ "decoders",
69
+ "Unet",
70
+ "UnetPlusPlus",
71
+ "MAnet",
72
+ "Linknet",
73
+ "FPN",
74
+ "LightFPN",
75
+ "PSPNet",
76
+ "DeepLabV3",
77
+ "DeepLabV3Plus",
78
+ "PAN",
79
+ "from_pretrained",
80
+ "create_model",
81
+ "__version__",
82
+ ]
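A hedged usage sketch of the `create_model` factory defined above; the architecture and encoder names below are just examples of valid keys, pretrained weights are skipped so the snippet runs offline, and the output shape assumes the default FPN upsampling.
```python
# Illustrative only: building a feature extractor through the factory above.
import torch

import feature_extractor_models as smp

model = smp.create_model(
    arch="fpn",               # any key of archs_dict, matched case-insensitively
    encoder_name="resnet34",
    encoder_weights=None,     # skip the ImageNet download for this sketch
    in_channels=3,
    classes=8,
).eval()
x = torch.rand(1, 3, 256, 256)  # H and W must be divisible by the encoder output stride
with torch.no_grad():
    y = model(x)                # (1, 8, 256, 256) for the default FPN upsampling
```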
feature_extractor_models/__version__.py ADDED
@@ -0,0 +1,3 @@
1
+ VERSION = (0, 3, "4dev0")
2
+
3
+ __version__ = ".".join(map(str, VERSION))
feature_extractor_models/base/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .model import SegmentationModel
2
+
3
+ from .modules import Conv2dReLU, Attention
4
+
5
+ from .heads import SegmentationHead, ClassificationHead
6
+
7
+ __all__ = [
8
+ "SegmentationModel",
9
+ "Conv2dReLU",
10
+ "Attention",
11
+ "SegmentationHead",
12
+ "ClassificationHead",
13
+ ]
feature_extractor_models/base/heads.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch.nn as nn
2
+ from .modules import Activation
3
+
4
+
5
+ class SegmentationHead(nn.Sequential):
6
+ def __init__(
7
+ self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
8
+ ):
9
+ conv2d = nn.Conv2d(
10
+ in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
11
+ )
12
+ upsampling = (
13
+ nn.UpsamplingBilinear2d(scale_factor=upsampling)
14
+ if upsampling > 1
15
+ else nn.Identity()
16
+ )
17
+ activation = Activation(activation)
18
+ super().__init__(conv2d, upsampling, activation)
19
+
20
+
21
+ class ClassificationHead(nn.Sequential):
22
+ def __init__(
23
+ self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
24
+ ):
25
+ if pooling not in ("max", "avg"):
26
+ raise ValueError(
27
+ "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
28
+ )
29
+ pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
30
+ flatten = nn.Flatten()
31
+ dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
32
+ linear = nn.Linear(in_channels, classes, bias=True)
33
+ activation = Activation(activation)
34
+ super().__init__(pool, flatten, dropout, linear, activation)
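A minimal shape check for the two heads above (illustrative channel counts and sizes only):
```python
# Illustrative shape check for SegmentationHead and ClassificationHead.
import torch

from feature_extractor_models.base.heads import SegmentationHead, ClassificationHead

seg_head = SegmentationHead(in_channels=128, out_channels=8, kernel_size=1, upsampling=4)
masks = seg_head(torch.rand(1, 128, 64, 64))     # (1, 8, 256, 256)

cls_head = ClassificationHead(in_channels=512, classes=7, pooling="avg", dropout=0.2)
labels = cls_head(torch.rand(1, 512, 8, 8))      # (1, 7) class logits
```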
feature_extractor_models/base/hub_mixin.py ADDED
@@ -0,0 +1,154 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from functools import wraps
5
+ from huggingface_hub import (
6
+ PyTorchModelHubMixin,
7
+ ModelCard,
8
+ ModelCardData,
9
+ hf_hub_download,
10
+ )
11
+
12
+
13
+ MODEL_CARD = """
14
+ ---
15
+ {{ card_data }}
16
+ ---
17
+ # {{ model_name }} Model Card
18
+
19
+ Table of Contents:
20
+ - [Load trained model](#load-trained-model)
21
+ - [Model init parameters](#model-init-parameters)
22
+ - [Model metrics](#model-metrics)
23
+ - [Dataset](#dataset)
24
+
25
+ ## Load trained model
26
+ ```python
27
+ import feature_extractor_models as smp
28
+
29
+ model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
30
+ ```
31
+
32
+ ## Model init parameters
33
+ ```python
34
+ model_init_params = {{ model_parameters }}
35
+ ```
36
+
37
+ ## Model metrics
38
+ {{ metrics | default("[More Information Needed]", true) }}
39
+
40
+ ## Dataset
41
+ Dataset name: {{ dataset | default("[More Information Needed]", true) }}
42
+
43
+ ## More Information
44
+ - Library: {{ repo_url | default("[More Information Needed]", true) }}
45
+ - Docs: {{ docs_url | default("[More Information Needed]", true) }}
46
+
47
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)
48
+ """
49
+
50
+
51
+ def _format_parameters(parameters: dict):
52
+ params = {k: v for k, v in parameters.items() if not k.startswith("_")}
53
+ params = [
54
+ f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"'
55
+ for k, v in params.items()
56
+ ]
57
+ params = ",\n".join([f" {param}" for param in params])
58
+ params = "{\n" + f"{params}" + "\n}"
59
+ return params
60
+
61
+
62
+ class SMPHubMixin(PyTorchModelHubMixin):
63
+ def generate_model_card(self, *args, **kwargs) -> ModelCard:
64
+ model_parameters_json = _format_parameters(self._hub_mixin_config)
65
+ directory = self._save_directory if hasattr(self, "_save_directory") else None
66
+ repo_id = self._repo_id if hasattr(self, "_repo_id") else None
67
+ repo_or_directory = repo_id if repo_id is not None else directory
68
+
69
+ metrics = self._metrics if hasattr(self, "_metrics") else None
70
+ dataset = self._dataset if hasattr(self, "_dataset") else None
71
+
72
+ if metrics is not None:
73
+ metrics = json.dumps(metrics, indent=4)
74
+ metrics = f"```json\n{metrics}\n```"
75
+
76
+ model_card_data = ModelCardData(
77
+ languages=["python"],
78
+ library_name="segmentation-models-pytorch",
79
+ license="mit",
80
+ tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
81
+ pipeline_tag="image-segmentation",
82
+ )
83
+ model_card = ModelCard.from_template(
84
+ card_data=model_card_data,
85
+ template_str=MODEL_CARD,
86
+ repo_url="https://github.com/qubvel/segmentation_models.pytorch",
87
+ docs_url="https://smp.readthedocs.io/en/latest/",
88
+ model_parameters=model_parameters_json,
89
+ save_directory=repo_or_directory,
90
+ model_name=self.__class__.__name__,
91
+ metrics=metrics,
92
+ dataset=dataset,
93
+ )
94
+ return model_card
95
+
96
+ def _set_attrs_from_kwargs(self, attrs, kwargs):
97
+ for attr in attrs:
98
+ if attr in kwargs:
99
+ setattr(self, f"_{attr}", kwargs.pop(attr))
100
+
101
+ def _del_attrs(self, attrs):
102
+ for attr in attrs:
103
+ if hasattr(self, f"_{attr}"):
104
+ delattr(self, f"_{attr}")
105
+
106
+ @wraps(PyTorchModelHubMixin.save_pretrained)
107
+ def save_pretrained(
108
+ self, save_directory: Union[str, Path], *args, **kwargs
109
+ ) -> Optional[str]:
110
+ # set additional attributes to be used in generate_model_card
111
+ self._save_directory = save_directory
112
+ self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
113
+
114
+ # set additional attribute to be used in from_pretrained
115
+ self._hub_mixin_config["_model_class"] = self.__class__.__name__
116
+
117
+ try:
118
+ # call the original save_pretrained
119
+ result = super().save_pretrained(save_directory, *args, **kwargs)
120
+ finally:
121
+ # delete the additional attributes
122
+ self._del_attrs(["save_directory", "metrics", "dataset"])
123
+ self._hub_mixin_config.pop("_model_class")
124
+
125
+ return result
126
+
127
+ @wraps(PyTorchModelHubMixin.push_to_hub)
128
+ def push_to_hub(self, repo_id: str, *args, **kwargs):
129
+ self._repo_id = repo_id
130
+ self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
131
+ result = super().push_to_hub(repo_id, *args, **kwargs)
132
+ self._del_attrs(["repo_id", "metrics", "dataset"])
133
+ return result
134
+
135
+ @property
136
+ def config(self):
137
+ return self._hub_mixin_config
138
+
139
+
140
+ @wraps(PyTorchModelHubMixin.from_pretrained)
141
+ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
142
+ config_path = hf_hub_download(
143
+ pretrained_model_name_or_path,
144
+ filename="config.json",
145
+ revision=kwargs.get("revision", None),
146
+ )
147
+ with open(config_path, "r") as f:
148
+ config = json.load(f)
149
+ model_class_name = config.pop("_model_class")
150
+
151
+ import feature_extractor_models as smp
152
+
153
+ model_class = getattr(smp, model_class_name)
154
+ return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
feature_extractor_models/base/initialization.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def initialize_decoder(module):
5
+ for m in module.modules():
6
+ if isinstance(m, nn.Conv2d):
7
+ nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
8
+ if m.bias is not None:
9
+ nn.init.constant_(m.bias, 0)
10
+
11
+ elif isinstance(m, nn.BatchNorm2d):
12
+ nn.init.constant_(m.weight, 1)
13
+ nn.init.constant_(m.bias, 0)
14
+
15
+ elif isinstance(m, nn.Linear):
16
+ nn.init.xavier_uniform_(m.weight)
17
+ if m.bias is not None:
18
+ nn.init.constant_(m.bias, 0)
19
+
20
+
21
+ def initialize_head(module):
22
+ for m in module.modules():
23
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
24
+ nn.init.xavier_uniform_(m.weight)
25
+ if m.bias is not None:
26
+ nn.init.constant_(m.bias, 0)
feature_extractor_models/base/model.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+
3
+ from . import initialization as init
4
+ from .hub_mixin import SMPHubMixin
5
+ import torch.nn as nn
6
+
7
+ class SegmentationModel(torch.nn.Module, SMPHubMixin):
8
+ def initialize(self):
9
+
10
+ # self.out = nn.Sequential(
11
+ # nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
12
+ # nn.BatchNorm2d(8),
13
+ # nn.ReLU(inplace=True),
14
+ # )
15
+ init.initialize_decoder(self.decoder)
16
+ init.initialize_head(self.segmentation_head)
17
+ if self.classification_head is not None:
18
+ init.initialize_head(self.classification_head)
19
+
20
+ def check_input_shape(self, x):
21
+ h, w = x.shape[-2:]
22
+ output_stride = self.encoder.output_stride
23
+ if h % output_stride != 0 or w % output_stride != 0:
24
+ new_h = (
25
+ (h // output_stride + 1) * output_stride
26
+ if h % output_stride != 0
27
+ else h
28
+ )
29
+ new_w = (
30
+ (w // output_stride + 1) * output_stride
31
+ if w % output_stride != 0
32
+ else w
33
+ )
34
+ raise RuntimeError(
35
+ f"Wrong input shape height={h}, width={w}. Expected image height and width "
36
+ f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
37
+ )
38
+
39
+ def forward(self, x):
40
+ """Sequentially pass `x` trough model`s encoder, decoder and heads"""
41
+
42
+ self.check_input_shape(x)
43
+
44
+ features = self.encoder(x)
45
+ decoder_output = self.decoder(*features)
46
+
47
+ decoder_output = self.segmentation_head(decoder_output)
48
+ #
49
+ # if self.classification_head is not None:
50
+ # labels = self.classification_head(features[-1])
51
+ # return masks, labels
52
+
53
+ return decoder_output
54
+
55
+ @torch.no_grad()
56
+ def predict(self, x):
57
+ """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
58
+
59
+ Args:
60
+ x: 4D torch tensor with shape (batch_size, channels, height, width)
61
+
62
+ Return:
63
+ prediction: 4D torch tensor with shape (batch_size, classes, height, width)
64
+
65
+ """
66
+ if self.training:
67
+ self.eval()
68
+
69
+ x = self.forward(x)
70
+
71
+ return x
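Since `check_input_shape` raises on inputs whose height or width is not divisible by the encoder output stride, a caller is expected to pad beforehand. A hedged sketch (the model choice and sizes are arbitrary, and the stride value is whatever the chosen encoder reports):
```python
# Illustrative only: padding the input so check_input_shape passes, then calling predict().
import torch
import torch.nn.functional as F

import feature_extractor_models as smp

model = smp.FPN(encoder_name="resnet34", encoder_weights=None, classes=2)
x = torch.rand(1, 3, 250, 250)                # 250 is not divisible by the stride
stride = model.encoder.output_stride           # typically 32 for ResNet encoders
pad_h = (stride - x.shape[-2] % stride) % stride
pad_w = (stride - x.shape[-1] % stride) % stride
x = F.pad(x, (0, pad_w, 0, pad_h))             # pad right/bottom -> (1, 3, 256, 256)
masks = model.predict(x)                       # switches to eval mode, runs forward under no_grad
```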
feature_extractor_models/base/modules.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ try:
5
+ from inplace_abn import InPlaceABN
6
+ except ImportError:
7
+ InPlaceABN = None
8
+
9
+
10
+ class Conv2dReLU(nn.Sequential):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ padding=0,
17
+ stride=1,
18
+ use_batchnorm=True,
19
+ ):
20
+ if use_batchnorm == "inplace" and InPlaceABN is None:
21
+ raise RuntimeError(
22
+ "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
23
+ + "To install see: https://github.com/mapillary/inplace_abn"
24
+ )
25
+
26
+ conv = nn.Conv2d(
27
+ in_channels,
28
+ out_channels,
29
+ kernel_size,
30
+ stride=stride,
31
+ padding=padding,
32
+ bias=not (use_batchnorm),
33
+ )
34
+ relu = nn.ReLU(inplace=True)
35
+
36
+ if use_batchnorm == "inplace":
37
+ bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
38
+ relu = nn.Identity()
39
+
40
+ elif use_batchnorm and use_batchnorm != "inplace":
41
+ bn = nn.BatchNorm2d(out_channels)
42
+
43
+ else:
44
+ bn = nn.Identity()
45
+
46
+ super(Conv2dReLU, self).__init__(conv, bn, relu)
47
+
48
+
49
+ class SCSEModule(nn.Module):
50
+ def __init__(self, in_channels, reduction=16):
51
+ super().__init__()
52
+ self.cSE = nn.Sequential(
53
+ nn.AdaptiveAvgPool2d(1),
54
+ nn.Conv2d(in_channels, in_channels // reduction, 1),
55
+ nn.ReLU(inplace=True),
56
+ nn.Conv2d(in_channels // reduction, in_channels, 1),
57
+ nn.Sigmoid(),
58
+ )
59
+ self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
60
+
61
+ def forward(self, x):
62
+ return x * self.cSE(x) + x * self.sSE(x)
63
+
64
+
65
+ class ArgMax(nn.Module):
66
+ def __init__(self, dim=None):
67
+ super().__init__()
68
+ self.dim = dim
69
+
70
+ def forward(self, x):
71
+ return torch.argmax(x, dim=self.dim)
72
+
73
+
74
+ class Clamp(nn.Module):
75
+ def __init__(self, min=0, max=1):
76
+ super().__init__()
77
+ self.min, self.max = min, max
78
+
79
+ def forward(self, x):
80
+ return torch.clamp(x, self.min, self.max)
81
+
82
+
83
+ class Activation(nn.Module):
84
+ def __init__(self, name, **params):
85
+ super().__init__()
86
+
87
+ if name is None or name == "identity":
88
+ self.activation = nn.Identity(**params)
89
+ elif name == "sigmoid":
90
+ self.activation = nn.Sigmoid()
91
+ elif name == "relu":
92
+ self.activation = nn.ReLU(inplace=True)
93
+ elif name == "softmax2d":
94
+ self.activation = nn.Softmax(dim=1, **params)
95
+ elif name == "softmax":
96
+ self.activation = nn.Softmax(**params)
97
+ elif name == "logsoftmax":
98
+ self.activation = nn.LogSoftmax(**params)
99
+ elif name == "tanh":
100
+ self.activation = nn.Tanh()
101
+ elif name == "argmax":
102
+ self.activation = ArgMax(**params)
103
+ elif name == "argmax2d":
104
+ self.activation = ArgMax(dim=1, **params)
105
+ elif name == "clamp":
106
+ self.activation = Clamp(**params)
107
+ elif callable(name):
108
+ self.activation = name(**params)
109
+ else:
110
+ raise ValueError(
111
+ f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
112
+ f"argmax/argmax2d/clamp/None; got {name}"
113
+ )
114
+
115
+ def forward(self, x):
116
+ return self.activation(x)
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(self, name, **params):
121
+ super().__init__()
122
+
123
+ if name is None:
124
+ self.attention = nn.Identity(**params)
125
+ elif name == "scse":
126
+ self.attention = SCSEModule(**params)
127
+ else:
128
+ raise ValueError("Attention {} is not implemented".format(name))
129
+
130
+ def forward(self, x):
131
+ return self.attention(x)
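The building blocks above compose as plain `nn.Module`s; a small, purely illustrative combination (channel counts are arbitrary):
```python
# Illustrative composition of Conv2dReLU, Attention and Activation.
import torch
import torch.nn as nn

from feature_extractor_models.base.modules import Conv2dReLU, Attention, Activation

block = nn.Sequential(
    Conv2dReLU(3, 16, kernel_size=3, padding=1, use_batchnorm=True),
    Attention("scse", in_channels=16),   # concurrent spatial + channel squeeze-and-excitation
    Activation("sigmoid"),
)
y = block(torch.rand(2, 3, 64, 64))      # (2, 16, 64, 64), values in (0, 1)
```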
feature_extractor_models/decoders/__init__.py ADDED
File without changes
feature_extractor_models/decoders/deeplabv3/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .model import DeepLabV3, DeepLabV3Plus
2
+
3
+ __all__ = ["DeepLabV3", "DeepLabV3Plus"]
feature_extractor_models/decoders/deeplabv3/decoder.py ADDED
@@ -0,0 +1,220 @@
1
+ """
2
+ BSD 3-Clause License
3
+
4
+ Copyright (c) Soumith Chintala 2016,
5
+ All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this
11
+ list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from
19
+ this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ """
32
+
33
+ import torch
34
+ from torch import nn
35
+ from torch.nn import functional as F
36
+
37
+ __all__ = ["DeepLabV3Decoder"]
38
+
39
+
40
+ class DeepLabV3Decoder(nn.Sequential):
41
+ def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
42
+ super().__init__(
43
+ ASPP(in_channels, out_channels, atrous_rates),
44
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
45
+ nn.BatchNorm2d(out_channels),
46
+ nn.ReLU(),
47
+ )
48
+ self.out_channels = out_channels
49
+
50
+ def forward(self, *features):
51
+ return super().forward(features[-1])
52
+
53
+
54
+ class DeepLabV3PlusDecoder(nn.Module):
55
+ def __init__(
56
+ self,
57
+ encoder_channels,
58
+ out_channels=256,
59
+ atrous_rates=(12, 24, 36),
60
+ output_stride=16,
61
+ ):
62
+ super().__init__()
63
+ if output_stride not in {8, 16}:
64
+ raise ValueError(
65
+ "Output stride should be 8 or 16, got {}.".format(output_stride)
66
+ )
67
+
68
+ self.out_channels = out_channels
69
+ self.output_stride = output_stride
70
+
71
+ self.aspp = nn.Sequential(
72
+ ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
73
+ SeparableConv2d(
74
+ out_channels, out_channels, kernel_size=3, padding=1, bias=False
75
+ ),
76
+ nn.BatchNorm2d(out_channels),
77
+ nn.ReLU(),
78
+ )
79
+
80
+ scale_factor = 2 if output_stride == 8 else 4
81
+ self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
82
+
83
+ highres_in_channels = encoder_channels[-4]
84
+ highres_out_channels = 48 # proposed by authors of paper
85
+ self.block1 = nn.Sequential(
86
+ nn.Conv2d(
87
+ highres_in_channels, highres_out_channels, kernel_size=1, bias=False
88
+ ),
89
+ nn.BatchNorm2d(highres_out_channels),
90
+ nn.ReLU(),
91
+ )
92
+ self.block2 = nn.Sequential(
93
+ SeparableConv2d(
94
+ highres_out_channels + out_channels,
95
+ out_channels,
96
+ kernel_size=3,
97
+ padding=1,
98
+ bias=False,
99
+ ),
100
+ nn.BatchNorm2d(out_channels),
101
+ nn.ReLU(),
102
+ )
103
+
104
+ def forward(self, *features):
105
+ aspp_features = self.aspp(features[-1])
106
+ aspp_features = self.up(aspp_features)
107
+ high_res_features = self.block1(features[-4])
108
+ concat_features = torch.cat([aspp_features, high_res_features], dim=1)
109
+ fused_features = self.block2(concat_features)
110
+ return fused_features
111
+
112
+
113
+ class ASPPConv(nn.Sequential):
114
+ def __init__(self, in_channels, out_channels, dilation):
115
+ super().__init__(
116
+ nn.Conv2d(
117
+ in_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ padding=dilation,
121
+ dilation=dilation,
122
+ bias=False,
123
+ ),
124
+ nn.BatchNorm2d(out_channels),
125
+ nn.ReLU(),
126
+ )
127
+
128
+
129
+ class ASPPSeparableConv(nn.Sequential):
130
+ def __init__(self, in_channels, out_channels, dilation):
131
+ super().__init__(
132
+ SeparableConv2d(
133
+ in_channels,
134
+ out_channels,
135
+ kernel_size=3,
136
+ padding=dilation,
137
+ dilation=dilation,
138
+ bias=False,
139
+ ),
140
+ nn.BatchNorm2d(out_channels),
141
+ nn.ReLU(),
142
+ )
143
+
144
+
145
+ class ASPPPooling(nn.Sequential):
146
+ def __init__(self, in_channels, out_channels):
147
+ super().__init__(
148
+ nn.AdaptiveAvgPool2d(1),
149
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
150
+ nn.BatchNorm2d(out_channels),
151
+ nn.ReLU(),
152
+ )
153
+
154
+ def forward(self, x):
155
+ size = x.shape[-2:]
156
+ for mod in self:
157
+ x = mod(x)
158
+ return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
159
+
160
+
161
+ class ASPP(nn.Module):
162
+ def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
163
+ super(ASPP, self).__init__()
164
+ modules = []
165
+ modules.append(
166
+ nn.Sequential(
167
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
168
+ nn.BatchNorm2d(out_channels),
169
+ nn.ReLU(),
170
+ )
171
+ )
172
+
173
+ rate1, rate2, rate3 = tuple(atrous_rates)
174
+ ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
175
+
176
+ modules.append(ASPPConvModule(in_channels, out_channels, rate1))
177
+ modules.append(ASPPConvModule(in_channels, out_channels, rate2))
178
+ modules.append(ASPPConvModule(in_channels, out_channels, rate3))
179
+ modules.append(ASPPPooling(in_channels, out_channels))
180
+
181
+ self.convs = nn.ModuleList(modules)
182
+
183
+ self.project = nn.Sequential(
184
+ nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
185
+ nn.BatchNorm2d(out_channels),
186
+ nn.ReLU(),
187
+ nn.Dropout(0.5),
188
+ )
189
+
190
+ def forward(self, x):
191
+ res = []
192
+ for conv in self.convs:
193
+ res.append(conv(x))
194
+ res = torch.cat(res, dim=1)
195
+ return self.project(res)
196
+
197
+
198
+ class SeparableConv2d(nn.Sequential):
199
+ def __init__(
200
+ self,
201
+ in_channels,
202
+ out_channels,
203
+ kernel_size,
204
+ stride=1,
205
+ padding=0,
206
+ dilation=1,
207
+ bias=True,
208
+ ):
209
+ depthwise_conv = nn.Conv2d(
210
+ in_channels,
211
+ in_channels,
212
+ kernel_size,
213
+ stride=stride,
214
+ padding=padding,
215
+ dilation=dilation,
216
+ groups=in_channels,
217
+ bias=False,
218
+ )
219
+ pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
220
+ super().__init__(depthwise_conv, pointwise_conv)
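As a shape reference for the ASPP block above (tensor sizes are illustrative; `eval()` avoids the single-sample BatchNorm restriction of the pooling branch when the batch size is 1):
```python
# Illustrative shape check for ASPP; sizes are arbitrary.
import torch

from feature_extractor_models.decoders.deeplabv3.decoder import ASPP

aspp = ASPP(in_channels=512, out_channels=256, atrous_rates=(12, 24, 36)).eval()
feats = torch.rand(1, 512, 32, 32)   # e.g. last encoder stage at output stride 8
with torch.no_grad():
    out = aspp(feats)                # (1, 256, 32, 32): spatial size is preserved
```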
feature_extractor_models/decoders/deeplabv3/model.py ADDED
@@ -0,0 +1,178 @@
1
+ from typing import Optional
2
+
3
+ from feature_extractor_models.base import (
4
+ SegmentationModel,
5
+ SegmentationHead,
6
+ ClassificationHead,
7
+ )
8
+ from feature_extractor_models.encoders import get_encoder
9
+ from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
10
+
11
+
12
+ class DeepLabV3(SegmentationModel):
13
+ """DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"
14
+
15
+ Args:
16
+ encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
17
+ to extract features of different spatial resolution
18
+ encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
19
+ two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
20
+ with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
21
+ Default is 5
22
+ encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
23
+ other pretrained weights (see table with available weights for each encoder_name)
24
+ decoder_channels: A number of convolution filters in ASPP module. Default is 256
25
+ in_channels: A number of input channels for the model, default is 3 (RGB images)
26
+ classes: A number of classes for output mask (or you can think as a number of channels of output mask)
27
+ activation: An activation function to apply after the final convolution layer.
28
+ Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
29
+ **callable** and **None**.
30
+ Default is **None**
31
+ upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
32
+ aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
33
+ on top of encoder if **aux_params** is not **None** (default). Supported params:
34
+ - classes (int): A number of classes
35
+ - pooling (str): One of "max", "avg". Default is "avg"
36
+ - dropout (float): Dropout factor in [0, 1)
37
+ - activation (str): An activation function to apply "sigmoid"/"softmax"
38
+ (could be **None** to return logits)
39
+ Returns:
40
+ ``torch.nn.Module``: **DeepLabV3**
41
+
42
+ .. _DeeplabV3:
43
+ https://arxiv.org/abs/1706.05587
44
+
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ encoder_name: str = "resnet34",
50
+ encoder_depth: int = 5,
51
+ encoder_weights: Optional[str] = "imagenet",
52
+ decoder_channels: int = 256,
53
+ in_channels: int = 3,
54
+ classes: int = 1,
55
+ activation: Optional[str] = None,
56
+ upsampling: int = 8,
57
+ aux_params: Optional[dict] = None,
58
+ ):
59
+ super().__init__()
60
+
61
+ self.encoder = get_encoder(
62
+ encoder_name,
63
+ in_channels=in_channels,
64
+ depth=encoder_depth,
65
+ weights=encoder_weights,
66
+ output_stride=8,
67
+ )
68
+
69
+ self.decoder = DeepLabV3Decoder(
70
+ in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels
71
+ )
72
+
73
+ self.segmentation_head = SegmentationHead(
74
+ in_channels=self.decoder.out_channels,
75
+ out_channels=classes,
76
+ activation=activation,
77
+ kernel_size=1,
78
+ upsampling=upsampling,
79
+ )
80
+
81
+ if aux_params is not None:
82
+ self.classification_head = ClassificationHead(
83
+ in_channels=self.encoder.out_channels[-1], **aux_params
84
+ )
85
+ else:
86
+ self.classification_head = None
87
+
88
+
89
+ class DeepLabV3Plus(SegmentationModel):
90
+ """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
91
+ Convolution for Semantic Image Segmentation"
92
+
93
+ Args:
94
+ encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
95
+ to extract features of different spatial resolution
96
+ encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
97
+ two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features
98
+ with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
99
+ Default is 5
100
+ encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
101
+ other pretrained weights (see table with available weights for each encoder_name)
102
+ encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
103
+ decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
104
+ decoder_channels: A number of convolution filters in ASPP module. Default is 256
105
+ in_channels: A number of input channels for the model, default is 3 (RGB images)
106
+ classes: A number of classes for output mask (or you can think as a number of channels of output mask)
107
+ activation: An activation function to apply after the final convolution layer.
108
+ Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
109
+ **callable** and **None**.
110
+ Default is **None**
111
+ upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
112
+ aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
113
+ on top of encoder if **aux_params** is not **None** (default). Supported params:
114
+ - classes (int): A number of classes
115
+ - pooling (str): One of "max", "avg". Default is "avg"
116
+ - dropout (float): Dropout factor in [0, 1)
117
+ - activation (str): An activation function to apply "sigmoid"/"softmax"
118
+ (could be **None** to return logits)
119
+ Returns:
120
+ ``torch.nn.Module``: **DeepLabV3Plus**
121
+
122
+ Reference:
123
+ https://arxiv.org/abs/1802.02611v3
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ encoder_name: str = "resnet34",
130
+ encoder_depth: int = 5,
131
+ encoder_weights: Optional[str] = "imagenet",
132
+ encoder_output_stride: int = 16,
133
+ decoder_channels: int = 256,
134
+ decoder_atrous_rates: tuple = (12, 24, 36),
135
+ in_channels: int = 3,
136
+ classes: int = 1,
137
+ activation: Optional[str] = None,
138
+ upsampling: int = 4,
139
+ aux_params: Optional[dict] = None,
140
+ ):
141
+ super().__init__()
142
+
143
+ if encoder_output_stride not in [8, 16]:
144
+ raise ValueError(
145
+ "Encoder output stride should be 8 or 16, got {}".format(
146
+ encoder_output_stride
147
+ )
148
+ )
149
+
150
+ self.encoder = get_encoder(
151
+ encoder_name,
152
+ in_channels=in_channels,
153
+ depth=encoder_depth,
154
+ weights=encoder_weights,
155
+ output_stride=encoder_output_stride,
156
+ )
157
+
158
+ self.decoder = DeepLabV3PlusDecoder(
159
+ encoder_channels=self.encoder.out_channels,
160
+ out_channels=decoder_channels,
161
+ atrous_rates=decoder_atrous_rates,
162
+ output_stride=encoder_output_stride,
163
+ )
164
+
165
+ self.segmentation_head = SegmentationHead(
166
+ in_channels=self.decoder.out_channels,
167
+ out_channels=classes,
168
+ activation=activation,
169
+ kernel_size=1,
170
+ upsampling=upsampling,
171
+ )
172
+
173
+ if aux_params is not None:
174
+ self.classification_head = ClassificationHead(
175
+ in_channels=self.encoder.out_channels[-1], **aux_params
176
+ )
177
+ else:
178
+ self.classification_head = None
feature_extractor_models/decoders/fpn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .model import FPN
2
+
3
+ __all__ = ["FPN"]
feature_extractor_models/decoders/fpn/decoder.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Conv3x3GNReLU(nn.Module):
7
+ def __init__(self, in_channels, out_channels, upsample=False):
8
+ super().__init__()
9
+ self.upsample = upsample
10
+ self.block = nn.Sequential(
11
+ nn.Conv2d(
12
+ in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
13
+ ),
14
+ nn.GroupNorm(32, out_channels),
15
+ nn.ReLU(inplace=True),
16
+ )
17
+
18
+ def forward(self, x):
19
+ x = self.block(x)
20
+ if self.upsample:
21
+ x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
22
+ return x
23
+
24
+
25
+ class FPNBlock(nn.Module):
26
+ def __init__(self, pyramid_channels, skip_channels):
27
+ super().__init__()
28
+ self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
29
+
30
+ def forward(self, x, skip=None):
31
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
32
+ skip = self.skip_conv(skip)
33
+ x = x + skip
34
+ return x
35
+
36
+
37
+ class SegmentationBlock(nn.Module):
38
+ def __init__(self, in_channels, out_channels, n_upsamples=0):
39
+ super().__init__()
40
+
41
+ blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
42
+
43
+ if n_upsamples > 1:
44
+ for _ in range(1, n_upsamples):
45
+ blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
46
+
47
+ self.block = nn.Sequential(*blocks)
48
+
49
+ def forward(self, x):
50
+ return self.block(x)
51
+
52
+
53
+ class MergeBlock(nn.Module):
54
+ def __init__(self, policy):
55
+ super().__init__()
56
+ if policy not in ["add", "cat"]:
57
+ raise ValueError(
58
+ "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
59
+ )
60
+ self.policy = policy
61
+
62
+ def forward(self, x):
63
+ if self.policy == "add":
64
+ return sum(x)
65
+ elif self.policy == "cat":
66
+ return torch.cat(x, dim=1)
67
+ else:
68
+ raise ValueError(
69
+ "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
70
+ self.policy
71
+ )
72
+ )
73
+
74
+
75
+ class FPNDecoder(nn.Module):
76
+ def __init__(
77
+ self,
78
+ encoder_channels,
79
+ encoder_depth=5,
80
+ pyramid_channels=256,
81
+ segmentation_channels=128,
82
+ dropout=0.2,
83
+ merge_policy="add",
84
+ ):
85
+ super().__init__()
86
+
87
+ self.out_channels = (
88
+ segmentation_channels
89
+ if merge_policy == "add"
90
+ else segmentation_channels * 4
91
+ )
92
+ if encoder_depth < 3:
93
+ raise ValueError(
94
+ "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
95
+ encoder_depth
96
+ )
97
+ )
98
+
99
+ encoder_channels = encoder_channels[::-1]
100
+ encoder_channels = encoder_channels[: encoder_depth + 1]
101
+
102
+ self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
103
+ self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
104
+ self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
105
+ self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
106
+
107
+ self.seg_blocks = nn.ModuleList(
108
+ [
109
+ SegmentationBlock(
110
+ pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
111
+ )
112
+ for n_upsamples in [3, 2, 1, 0]
113
+ ]
114
+ )
115
+
116
+ self.merge = MergeBlock(merge_policy)
117
+ self.dropout = nn.Dropout2d(p=dropout, inplace=True)
118
+
119
+ def forward(self, *features):
120
+ c2, c3, c4, c5 = features[-4:]
121
+
122
+ p5 = self.p5(c5)
123
+ p4 = self.p4(p5, c4)
124
+ p3 = self.p3(p4, c3)
125
+ p2 = self.p2(p3, c2)
126
+
127
+ feature_pyramid = [
128
+ seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
129
+ ]
130
+ x = self.merge(feature_pyramid)
131
+ x = self.dropout(x)
132
+
133
+ return x
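A quick, hedged smoke test of the decoder contract: it only consumes the last four feature maps and returns stride-4 features whose width is segmentation_channels (or four times that with merge_policy="cat"). The channel and spatial sizes below imitate a resnet34-style encoder and are assumptions for illustration only.

import torch
from feature_extractor_models.decoders.fpn.decoder import FPNDecoder

channels = (3, 64, 64, 128, 256, 512)                      # resnet34-like out_channels
feats = [torch.rand(1, c, 256 // s, 256 // s)
         for c, s in zip(channels, (1, 2, 4, 8, 16, 32))]  # strides 1..32 for a 256px input
decoder = FPNDecoder(encoder_channels=channels, pyramid_channels=256,
                     segmentation_channels=128, merge_policy="add")
out = decoder(*feats)
print(out.shape)  # torch.Size([1, 128, 64, 64]) -> a stride-4 feature map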
feature_extractor_models/decoders/fpn/model.py ADDED
@@ -0,0 +1,107 @@
+from typing import Optional
+
+from feature_extractor_models.base import (
+    SegmentationModel,
+    SegmentationHead,
+    ClassificationHead,
+)
+from feature_extractor_models.encoders import get_encoder
+from .decoder import FPNDecoder
+
+
+class FPN(SegmentationModel):
+    """FPN_ is a fully convolutional neural network for image semantic segmentation.
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_
+        decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_
+        decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add**
+            and **cat**
+        decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **FPN**
+
+    .. _FPN:
+        http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_pyramid_channels: int = 256,
+        decoder_segmentation_channels: int = 128,
+        decoder_merge_policy: str = "add",
+        decoder_dropout: float = 0.2,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[str] = None,
+        upsampling: int = 4,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        # validate input params
+        if encoder_name.startswith("mit_b") and encoder_depth != 5:
+            raise ValueError(
+                "Encoder {} supports only encoder_depth=5".format(encoder_name)
+            )
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = FPNDecoder(
+            encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
+            pyramid_channels=decoder_pyramid_channels,
+            segmentation_channels=decoder_segmentation_channels,
+            dropout=decoder_dropout,
+            merge_policy=decoder_merge_policy,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=self.decoder.out_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=1,
+            upsampling=upsampling,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "fpn-{}".format(encoder_name)
+        self.initialize()
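A hedged usage sketch for the full model. encoder_weights=None avoids the pretrained download, the class counts are arbitrary, and the tuple return assumes the base SegmentationModel follows the upstream segmentation_models_pytorch convention of returning (mask, labels) when an aux head is attached:

import torch
from feature_extractor_models.decoders.fpn import FPN

model = FPN(encoder_name="resnet34", encoder_weights=None, classes=7,
            aux_params=dict(classes=7, pooling="avg", dropout=0.2))
model.eval()
with torch.no_grad():
    mask, labels = model(torch.rand(1, 3, 256, 256))
print(mask.shape, labels.shape)  # expected: (1, 7, 256, 256) and (1, 7)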
feature_extractor_models/decoders/lightfpn/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .model import LightFPN
+
+__all__ = ["LightFPN"]
feature_extractor_models/decoders/lightfpn/decoder.py ADDED
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Conv3x3GNReLU(nn.Module):
+    def __init__(self, in_channels, out_channels, upsample=False):
+        super().__init__()
+        self.upsample = upsample
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
+            ),
+            nn.GroupNorm(32, out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        x = self.block(x)
+        if self.upsample:
+            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
+        return x
+
+
+class DepthwiseSeparableConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+        super().__init__()
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        return x
+
+
+class LightFPNBlock(nn.Module):
+    def __init__(self, pyramid_channels, skip_channels):
+        super().__init__()
+        self.skip_conv = DepthwiseSeparableConv2d(skip_channels, pyramid_channels, kernel_size=1)
+
+    def forward(self, x, skip=None):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        skip = self.skip_conv(skip)
+        x = x + skip
+        return x
+
+
+class SegmentationBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, n_upsamples=0):
+        super().__init__()
+
+        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
+
+        if n_upsamples > 1:
+            for _ in range(1, n_upsamples):
+                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
+
+        self.block = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class MergeBlock(nn.Module):
+    def __init__(self, policy):
+        super().__init__()
+        if policy not in ["add", "cat"]:
+            raise ValueError(
+                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
+            )
+        self.policy = policy
+
+    def forward(self, x):
+        if self.policy == "add":
+            return sum(x)
+        elif self.policy == "cat":
+            return torch.cat(x, dim=1)
+        else:
+            raise ValueError(
+                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
+                    self.policy
+                )
+            )
+
+
+class FPNDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        encoder_depth=5,
+        pyramid_channels=256,
+        segmentation_channels=128,
+        dropout=0.2,
+        merge_policy="add",
+    ):
+        super().__init__()
+
+        self.out_channels = (
+            segmentation_channels
+            if merge_policy == "add"
+            else segmentation_channels * 4
+        )
+        if encoder_depth < 3:
+            raise ValueError(
+                "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
+
+        encoder_channels = encoder_channels[::-1]
+        encoder_channels = encoder_channels[: encoder_depth + 1]
+
+        self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
+        self.p4 = LightFPNBlock(pyramid_channels, encoder_channels[1])
+        self.p3 = LightFPNBlock(pyramid_channels, encoder_channels[2])
+        self.p2 = LightFPNBlock(pyramid_channels, encoder_channels[3])
+
+        self.seg_blocks = nn.ModuleList(
+            [
+                SegmentationBlock(
+                    pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
+                )
+                for n_upsamples in [3, 2, 1, 0]
+            ]
+        )
+
+        self.merge = MergeBlock(merge_policy)
+        self.dropout = nn.Dropout2d(p=dropout, inplace=True)
+
+    def forward(self, *features):
+        c2, c3, c4, c5 = features[-4:]
+
+        p5 = self.p5(c5)
+        p4 = self.p4(p5, c4)
+        p3 = self.p3(p4, c3)
+        p2 = self.p2(p3, c2)
+
+        feature_pyramid = [
+            seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
+        ]
+        x = self.merge(feature_pyramid)
+        x = self.dropout(x)
+
+        return x
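The only change relative to the plain FPN decoder is that the lateral projections go through DepthwiseSeparableConv2d. A hedged illustration of the factorization (channel counts are arbitrary; a 3x3 kernel is used here to show the general parameter saving, whereas the lateral connections above use kernel_size=1, where the split changes structure more than raw parameter count):

import torch.nn as nn
from feature_extractor_models.decoders.lightfpn.decoder import DepthwiseSeparableConv2d

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# A k x k depthwise followed by a 1x1 pointwise needs roughly C_in*k*k + C_in*C_out
# weights instead of C_in*C_out*k*k for a dense convolution.
print(n_params(nn.Conv2d(256, 256, kernel_size=3, padding=1)))                 # ~590k
print(n_params(DepthwiseSeparableConv2d(256, 256, kernel_size=3, padding=1)))  # ~68k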
feature_extractor_models/decoders/lightfpn/model.py ADDED
@@ -0,0 +1,107 @@
+from typing import Optional
+
+from feature_extractor_models.base import (
+    SegmentationModel,
+    SegmentationHead,
+    ClassificationHead,
+)
+from feature_extractor_models.encoders import get_encoder
+from .decoder import FPNDecoder
+
+
+class LightFPN(SegmentationModel):
+    """LightFPN is a lightweight variant of FPN_ for image semantic segmentation, using depthwise-separable lateral convolutions.
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_pyramid_channels: A number of convolution filters in Feature Pyramid of FPN_
+        decoder_segmentation_channels: A number of convolution filters in segmentation blocks of FPN_
+        decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add**
+            and **cat**
+        decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **LightFPN**
+
+    .. _FPN:
+        http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_pyramid_channels: int = 256,
+        decoder_segmentation_channels: int = 128,
+        decoder_merge_policy: str = "add",
+        decoder_dropout: float = 0.2,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[str] = None,
+        upsampling: int = 4,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        # validate input params
+        if encoder_name.startswith("mit_b") and encoder_depth != 5:
+            raise ValueError(
+                "Encoder {} supports only encoder_depth=5".format(encoder_name)
+            )
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = FPNDecoder(
+            encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
+            pyramid_channels=decoder_pyramid_channels,
+            segmentation_channels=decoder_segmentation_channels,
+            dropout=decoder_dropout,
+            merge_policy=decoder_merge_policy,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=self.decoder.out_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=1,
+            upsampling=upsampling,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "fpn-{}".format(encoder_name)
+        self.initialize()
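Because LightFPN keeps the exact constructor signature of FPN, the two can be swapped behind a single switch. A minimal sketch (the helper name and class count are illustrative, not part of the repo):

from feature_extractor_models.decoders.fpn import FPN
from feature_extractor_models.decoders.lightfpn import LightFPN

def build_extractor(arch: str = "fpn", **kwargs):
    # Pick the decoder family by name; both classes accept the same keyword arguments.
    cls = {"fpn": FPN, "lightfpn": LightFPN}[arch]
    return cls(encoder_name="resnet34", encoder_weights=None, classes=8, **kwargs)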
feature_extractor_models/decoders/linknet/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .model import Linknet
+
+__all__ = ["Linknet"]
feature_extractor_models/decoders/linknet/decoder.py ADDED
@@ -0,0 +1,82 @@
+import torch.nn as nn
+
+from feature_extractor_models.base import modules
+
+
+class TransposeX2(nn.Sequential):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+        super().__init__()
+        layers = [
+            nn.ConvTranspose2d(
+                in_channels, out_channels, kernel_size=4, stride=2, padding=1
+            ),
+            nn.ReLU(inplace=True),
+        ]
+
+        if use_batchnorm:
+            layers.insert(1, nn.BatchNorm2d(out_channels))
+
+        super().__init__(*layers)
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            modules.Conv2dReLU(
+                in_channels,
+                in_channels // 4,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+            ),
+            TransposeX2(
+                in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
+            ),
+            modules.Conv2dReLU(
+                in_channels // 4,
+                out_channels,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+            ),
+        )
+
+    def forward(self, x, skip=None):
+        x = self.block(x)
+        if skip is not None:
+            x = x + skip
+        return x
+
+
+class LinknetDecoder(nn.Module):
+    def __init__(
+        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True
+    ):
+        super().__init__()
+
+        # remove first skip
+        encoder_channels = encoder_channels[1:]
+        # reverse channels to start from head of encoder
+        encoder_channels = encoder_channels[::-1]
+
+        channels = list(encoder_channels) + [prefinal_channels]
+
+        self.blocks = nn.ModuleList(
+            [
+                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                for i in range(n_blocks)
+            ]
+        )
+
+    def forward(self, *features):
+        features = features[1:]  # remove first skip
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        x = features[0]
+        skips = features[1:]
+
+        for i, decoder_block in enumerate(self.blocks):
+            skip = skips[i] if i < len(skips) else None
+            x = decoder_block(x, skip)
+
+        return x
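The sum fusion in DecoderBlock works without extra projections because each block's output width is chosen to match the next skip tensor. A hedged walk-through for a resnet34-style encoder (the channel counts are an assumption for illustration):

# Linknet channel bookkeeping sketch.
encoder_channels = (3, 64, 64, 128, 256, 512)        # resnet34-like encoder outputs
channels = list(encoder_channels[1:][::-1]) + [32]   # [512, 256, 128, 64, 64, 32]
blocks = list(zip(channels[:-1], channels[1:]))      # (in, out) per DecoderBlock
print(blocks)  # [(512, 256), (256, 128), (128, 64), (64, 64), (64, 32)]
# Each output matches the channel count of the skip added next, so `x = x + skip`
# in DecoderBlock is shape-compatible at every stage.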
feature_extractor_models/decoders/linknet/model.py ADDED
@@ -0,0 +1,98 @@
+from typing import Optional, Union
+
+from feature_extractor_models.base import (
+    SegmentationHead,
+    SegmentationModel,
+    ClassificationHead,
+)
+from feature_extractor_models.encoders import get_encoder
+from .decoder import LinknetDecoder
+
+
+class Linknet(SegmentationModel):
+    """Linknet_ is a fully convolutional neural network for image semantic segmentation. It consists of an *encoder*
+    and a *decoder* connected by *skip connections*. The encoder extracts features of different spatial
+    resolutions (skip connections), which the decoder uses to build an accurate segmentation mask.
+    Decoder blocks are fused with the skip connections by *summation*.
+
+    Note:
+        This implementation by default has 4 skip connections (original - 3).
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+            is used. If **"inplace"**, InplaceABN will be used, which allows decreasing memory consumption.
+            Available options are **True, False, "inplace"**
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **Linknet**
+
+    .. _Linknet:
+        https://arxiv.org/abs/1707.03718
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_use_batchnorm: bool = True,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        if encoder_name.startswith("mit_b"):
+            raise ValueError(
+                "Encoder `{}` is not supported for Linknet".format(encoder_name)
+            )
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = LinknetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            n_blocks=encoder_depth,
+            prefinal_channels=32,
+            use_batchnorm=decoder_use_batchnorm,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=32, out_channels=classes, activation=activation, kernel_size=1
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "link-{}".format(encoder_name)
+        self.initialize()
feature_extractor_models/decoders/manet/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .model import MAnet
+
+__all__ = ["MAnet"]
feature_extractor_models/decoders/manet/decoder.py ADDED
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from feature_extractor_models.base import modules as md
+
+
+class PAB(nn.Module):
+    def __init__(self, in_channels, out_channels, pab_channels=64):
+        super(PAB, self).__init__()
+        # Series of 1x1 conv to generate attention feature maps
+        self.pab_channels = pab_channels
+        self.in_channels = in_channels
+        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
+        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
+        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+        self.map_softmax = nn.Softmax(dim=1)
+        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        bsize = x.size()[0]
+        h = x.size()[2]
+        w = x.size()[3]
+        x_top = self.top_conv(x)
+        x_center = self.center_conv(x)
+        x_bottom = self.bottom_conv(x)
+
+        x_top = x_top.flatten(2)
+        x_center = x_center.flatten(2).transpose(1, 2)
+        x_bottom = x_bottom.flatten(2).transpose(1, 2)
+
+        sp_map = torch.matmul(x_center, x_top)
+        sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w)
+        sp_map = torch.matmul(sp_map, x_bottom)
+        sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
+        x = x + sp_map
+        x = self.out_conv(x)
+        return x
+
+
+class MFAB(nn.Module):
+    def __init__(
+        self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16
+    ):
+        # MFAB is just a modified version of SE-blocks, one for skip, one for input
+        super(MFAB, self).__init__()
+        self.hl_conv = nn.Sequential(
+            md.Conv2dReLU(
+                in_channels,
+                in_channels,
+                kernel_size=3,
+                padding=1,
+                use_batchnorm=use_batchnorm,
+            ),
+            md.Conv2dReLU(
+                in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
+            ),
+        )
+        reduced_channels = max(1, skip_channels // reduction)
+        self.SE_ll = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(skip_channels, reduced_channels, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(reduced_channels, skip_channels, 1),
+            nn.Sigmoid(),
+        )
+        self.SE_hl = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(skip_channels, reduced_channels, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(reduced_channels, skip_channels, 1),
+            nn.Sigmoid(),
+        )
+        self.conv1 = md.Conv2dReLU(
+            skip_channels
+            + skip_channels,  # we transform C-prime from high level to C from skip connection
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x, skip=None):
+        x = self.hl_conv(x)
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        attention_hl = self.SE_hl(x)
+        if skip is not None:
+            attention_ll = self.SE_ll(skip)
+            attention_hl = attention_hl + attention_ll
+        x = x * attention_hl
+        x = torch.cat([x, skip], dim=1)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True):
+        super().__init__()
+        self.conv1 = md.Conv2dReLU(
+            in_channels + skip_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x, skip=None):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if skip is not None:
+            x = torch.cat([x, skip], dim=1)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class MAnetDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        decoder_channels,
+        n_blocks=5,
+        reduction=16,
+        use_batchnorm=True,
+        pab_channels=64,
+    ):
+        super().__init__()
+
+        if n_blocks != len(decoder_channels):
+            raise ValueError(
+                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
+                    n_blocks, len(decoder_channels)
+                )
+            )
+
+        # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[1:]
+
+        # reverse channels to start from head of encoder
+        encoder_channels = encoder_channels[::-1]
+
+        # computing blocks input and output channels
+        head_channels = encoder_channels[0]
+        in_channels = [head_channels] + list(decoder_channels[:-1])
+        skip_channels = list(encoder_channels[1:]) + [0]
+        out_channels = decoder_channels
+
+        self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)
+
+        # combine decoder keyword arguments
+        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
+        blocks = [
+            MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
+            if skip_ch > 0
+            else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
+            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
+        ]
+        # for the last block we don't have a skip connection -> use simple decoder block
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, *features):
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        head = features[0]
+        skips = features[1:]
+
+        x = self.center(head)
+        for i, decoder_block in enumerate(self.blocks):
+            skip = skips[i] if i < len(skips) else None
+            x = decoder_block(x, skip)
+
+        return x
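One practical point about PAB: the softmaxed map has shape (B, H*W, H*W), so it is intended for the low-resolution head feature (stride 32), not for early high-resolution maps. A hedged smoke test with invented sizes:

import torch
from feature_extractor_models.decoders.manet.decoder import PAB

pab = PAB(in_channels=512, out_channels=512, pab_channels=64)
y = pab(torch.rand(1, 512, 8, 8))   # 8*8 = 64 positions -> a 64x64 attention map
print(y.shape)                      # torch.Size([1, 512, 8, 8])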
feature_extractor_models/decoders/manet/model.py ADDED
@@ -0,0 +1,102 @@
+from typing import Optional, Union, List
+
+from feature_extractor_models.encoders import get_encoder
+from feature_extractor_models.base import (
+    SegmentationModel,
+    SegmentationHead,
+    ClassificationHead,
+)
+from .decoder import MAnetDecoder
+
+
+class MAnet(SegmentationModel):
+    """MAnet_ : Multi-scale Attention Net. MA-Net can capture rich contextual dependencies based on
+    the attention mechanism, using two blocks:
+        - Position-wise Attention Block (PAB), which captures the spatial dependencies between pixels in a global view
+        - Multi-scale Fusion Attention Block (MFAB), which captures the channel dependencies between any feature map by
+          multi-scale semantic feature fusion
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in the encoder, in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
+            Length of the list should be the same as **encoder_depth**
+        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+            is used. If **"inplace"**, InplaceABN will be used, which allows decreasing memory consumption.
+            Available options are **True, False, "inplace"**
+        decoder_pab_channels: A number of channels for PAB module in decoder.
+            Default is 64.
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of the encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **MAnet**
+
+    .. _MAnet:
+        https://ieeexplore.ieee.org/abstract/document/9201310
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_use_batchnorm: bool = True,
+        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_pab_channels: int = 64,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = MAnetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            decoder_channels=decoder_channels,
+            n_blocks=encoder_depth,
+            use_batchnorm=decoder_use_batchnorm,
+            pab_channels=decoder_pab_channels,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_channels[-1],
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "manet-{}".format(encoder_name)
+        self.initialize()
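A hedged usage sketch (encoder_weights=None skips the pretrained download; the class count and input size are arbitrary). With encoder_depth=5 the decoder upsamples five times, so the mask comes back at the input resolution:

import torch
from feature_extractor_models.decoders.manet import MAnet

model = MAnet(encoder_name="resnet34", encoder_weights=None,
              decoder_channels=(256, 128, 64, 32, 16), classes=8)
model.eval()
with torch.no_grad():
    mask = model(torch.rand(1, 3, 256, 256))
print(mask.shape)  # expected: torch.Size([1, 8, 256, 256])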
feature_extractor_models/decoders/pan/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .model import PAN
+
+__all__ = ["PAN"]