File size: 4,768 Bytes
3e8fdd4
 
 
 
 
 
 
 
 
 
 
 
 
 
9456d71
3e8fdd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84f14a2
3e8fdd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
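"""SwinUNETR v2 models wrapped as Hugging Face ``transformers`` custom models.

Each model builds a MONAI ``SwinUNETR`` with ``use_v2=True`` from a
``SwinUNETRv2Config``. For background on custom models, see the Hugging Face
guide: https://huggingface.co/docs/transformers/en/custom_models
"""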

from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from transformers import PreTrainedModel
from monai.networks.nets import SwinUNETR

from magdi_segmentation_models_3d.models.swinunetrv2.configuration_swinvunetr2 import (
    SwinUNETRv2Config,
)


# @auto_docstring
class SwinUNETRv2PreTrainedModel(PreTrainedModel):
    """Shared ``transformers`` base class for every SwinUNETR v2 model here.

    Its only job is to declare ``SwinUNETRv2Config`` as the configuration class.
    ``transformers`` reads ``config_class`` to pair a model with its config.
    """

    config_class: type = SwinUNETRv2Config


# @auto_docstring
class SwinUNETRv2Model(SwinUNETRv2PreTrainedModel):
    """Bare SwinUNETR v2 network exposed as a ``transformers`` model.

    A MONAI ``SwinUNETR`` is built with ``use_v2=True`` from the matching
    attributes of ``config`` and stored as ``self.model``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Each SwinUNETR keyword below is read from the config attribute of the
        # same name, in this order. ``norm_layer`` is not forwarded, so
        # SwinUNETR's own default applies.
        forwarded = (
            "in_channels",
            "out_channels",
            "patch_size",
            "depths",
            "num_heads",
            "window_size",
            "qkv_bias",
            "mlp_ratio",
            "feature_size",
            "norm_name",
            "drop_rate",
            "attn_drop_rate",
            "dropout_path_rate",
            "normalize",
            "patch_norm",
            "use_checkpoint",
            "spatial_dims",
            "downsample",
        )
        network_kwargs = {key: getattr(config, key) for key in forwarded}
        self.model = SwinUNETR(use_v2=True, **network_kwargs)

    def forward(self, tensor):
        """Run the wrapped SwinUNETR on ``tensor`` and return its raw output."""
        output = self.model(tensor)
        return output


# @auto_docstring
class SwinUNETRv2ForImageSegmentation(SwinUNETRv2PreTrainedModel):
    """SwinUNETR v2 segmentation model returning logits and a Dice+CE loss.

    The MONAI ``SwinUNETR`` (``use_v2=True``) is stored as ``self.model``.
    The loss is only computed when annotations are supplied.
    """

    config_class = SwinUNETRv2Config

    def __init__(self, config):
        super().__init__(config)
        self.model = SwinUNETR(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            patch_size=config.patch_size,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            qkv_bias=config.qkv_bias,
            mlp_ratio=config.mlp_ratio,
            feature_size=config.feature_size,
            norm_name=config.norm_name,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            dropout_path_rate=config.dropout_path_rate,
            normalize=config.normalize,
            # norm_layer=config.norm_layer,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            spatial_dims=config.spatial_dims,
            downsample=config.downsample,
            use_v2=True,
        )

    def forward(self, tensor, train=False, roi_size=(128, 128, 128), sw_batch_size=1):
        """Compute segmentation logits and, when labels are available, the loss.

        Args:
            tensor: Mapping with an ``"image"`` entry and an optional
                ``"annotations"`` entry. The annotations are required when
                ``train`` is true.
            train: If true, run the network on the whole image in one pass.
                Otherwise use MONAI sliding-window inference.
            roi_size: Window size for sliding-window inference; ignored when
                ``train`` is true. The 3-element default assumes volumetric
                input (TODO confirm it matches ``config.spatial_dims``).
            sw_batch_size: Number of windows evaluated per batch during
                sliding-window inference; ignored when ``train`` is true.

        Returns:
            dict with ``"logits"`` and ``"loss"``. ``"loss"`` is ``None`` when
            evaluating without ``"annotations"``.

        Raises:
            KeyError: if ``"image"`` is missing, or if ``"annotations"`` is
                missing while ``train`` is true.
        """
        image = tensor["image"]
        # Labels are mandatory for training but optional for inference, so the
        # model can also predict on unlabeled images.
        try:
            annotations = tensor["annotations"]
        except KeyError:
            if train:
                raise
            annotations = None

        if train:
            logits = self.model(image)
        else:
            # Pass the module itself (not ``.forward``) so nn.Module hooks run.
            logits = sliding_window_inference(
                image,
                roi_size,
                sw_batch_size,
                self.model,
            )

        loss = None
        if annotations is not None:
            # to_onehot_y=True: annotations are presumably integer label maps
            # with a singleton channel dim (TODO confirm against the data pipeline).
            criterion = DiceCELoss(to_onehot_y=True, softmax=True)
            loss = criterion(logits, annotations)

        return {
            "logits": logits,
            "loss": loss,
        }


# @auto_docstring
class SwinUNETRv2Backbone(SwinUNETRv2PreTrainedModel):
    """Swin transformer encoder of a SwinUNETR v2 network, without the decoder.

    A full MONAI ``SwinUNETR`` (``use_v2=True``) is instantiated. Only its
    ``swinViT`` encoder submodule is kept, as ``self.swinViT``.
    """

    config_class = SwinUNETRv2Config

    def __init__(self, config):
        super().__init__(config)
        # The temporary SwinUNETR is discarded after its encoder is extracted.
        self.swinViT = SwinUNETR(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            patch_size=config.patch_size,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            qkv_bias=config.qkv_bias,
            mlp_ratio=config.mlp_ratio,
            feature_size=config.feature_size,
            norm_name=config.norm_name,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            dropout_path_rate=config.dropout_path_rate,
            normalize=config.normalize,
            # norm_layer=config.norm_layer,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            spatial_dims=config.spatial_dims,
            downsample=config.downsample,
            use_v2=True,
        ).swinViT

    def forward(self, tensor, normalize=None):
        """Run the Swin encoder on ``tensor``.

        Args:
            tensor: Input image batch passed straight to the encoder.
            normalize: Whether the encoder normalizes its intermediate
                features. Defaults to ``config.normalize``, which is the value
                SwinUNETR itself passes to its encoder.

        Returns:
            Whatever MONAI's Swin encoder returns. This is presumably the list
            of multi-scale hidden states; confirm against the installed MONAI
            version.
        """
        # The encoder is stored as ``self.swinViT``; no ``self.model`` exists
        # on this class.
        if normalize is None:
            normalize = self.config.normalize
        return self.swinViT(tensor, normalize)


# Public names exported by ``from <this module> import *``.
__all__ = [
    "SwinUNETRv2ForImageSegmentation",
    "SwinUNETRv2Model",
    "SwinUNETRv2PreTrainedModel",
    "SwinUNETRv2Backbone",
]