|
"""
|
|
Documentation on Hugging Face: https://huggingface.co/docs/transformers/en/custom_models
|
|
"""
|
|
|
|
from monai.inferers import sliding_window_inference
|
|
from monai.losses import DiceCELoss
|
|
from transformers import PreTrainedModel
|
|
from monai.networks.nets import SwinUNETR
|
|
|
|
from magdi_segmentation_models_3d.models.swinunetrv2.configuration_swinvunetr2 import (
|
|
SwinUNETRv2Config,
|
|
)
|
|
|
|
|
|
|
|
class SwinUNETRv2PreTrainedModel(PreTrainedModel):
    """Base class hooking SwinUNETR v2 into the Hugging Face ``PreTrainedModel``
    machinery (weight init, ``save_pretrained`` / ``from_pretrained``)."""

    # Tells transformers which configuration class to (de)serialize with.
    config_class = SwinUNETRv2Config
|
|
|
|
|
|
|
|
class SwinUNETRv2Model(SwinUNETRv2PreTrainedModel):
    """Hugging Face wrapper exposing a raw MONAI SwinUNETR (v2) network.

    The forward pass simply delegates to the underlying MONAI module.
    """

    def __init__(self, config):
        super().__init__(config)
        # Translate the HF config into the MONAI constructor's keyword
        # arguments, then build the network in one splat call.
        monai_kwargs = dict(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            patch_size=config.patch_size,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            qkv_bias=config.qkv_bias,
            mlp_ratio=config.mlp_ratio,
            feature_size=config.feature_size,
            norm_name=config.norm_name,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            dropout_path_rate=config.dropout_path_rate,
            normalize=config.normalize,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            spatial_dims=config.spatial_dims,
            downsample=config.downsample,
            use_v2=True,  # always instantiate the v2 variant
        )
        self.model = SwinUNETR(**monai_kwargs)

    def forward(self, tensor):
        """Run the wrapped SwinUNETR on *tensor* and return its output."""
        return self.model(tensor)
|
|
|
|
|
|
|
|
class SwinUNETRv2ForImageSegmentation(SwinUNETRv2PreTrainedModel):
    """SwinUNETR v2 with a Dice + cross-entropy head for 3D image segmentation.

    ``forward`` expects a dict with keys ``"image"`` and ``"annotations"``
    and returns a dict with ``"logits"`` and ``"loss"``.
    """

    config_class = SwinUNETRv2Config

    def __init__(self, config):
        super().__init__(config)
        self.model = SwinUNETR(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            patch_size=config.patch_size,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            qkv_bias=config.qkv_bias,
            mlp_ratio=config.mlp_ratio,
            feature_size=config.feature_size,
            norm_name=config.norm_name,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            dropout_path_rate=config.dropout_path_rate,
            normalize=config.normalize,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            spatial_dims=config.spatial_dims,
            downsample=config.downsample,
            use_v2=True,  # always the v2 variant of SwinUNETR
        )
        # The loss has no data-dependent state, so build it once here instead
        # of re-constructing it on every forward call (as the old code did).
        self.criterion = DiceCELoss(to_onehot_y=True, softmax=True)

    def forward(self, tensor, train=False, roi_size=(128, 128, 128), sw_batch_size=1):
        """Compute logits and DiceCE loss for one batch.

        Args:
            tensor: dict with ``"image"`` (input volume) and ``"annotations"``
                (ground-truth labels; one-hot encoding handled by the loss).
            train: if True, run a direct full-volume pass; otherwise use
                sliding-window inference over ``roi_size`` patches.
            roi_size: patch size for sliding-window inference.
            sw_batch_size: number of windows evaluated per inference step.

        Returns:
            dict with ``"logits"`` and ``"loss"``.
        """
        image = tensor["image"]
        annotations = tensor["annotations"]

        if train:
            logits = self.model(image)
        else:
            # Sliding-window inference lets us segment volumes larger than
            # the network's ROI at evaluation time.
            logits = sliding_window_inference(
                image,
                roi_size,
                sw_batch_size,
                self.model.forward,
            )
        # Loss computation is identical in both branches, so do it once.
        loss = self.criterion(logits, annotations)

        return {
            "logits": logits,
            "loss": loss,
        }
|
|
|
|
|
|
|
|
class SwinUNETRv2Backbone(SwinUNETRv2PreTrainedModel):
    """Exposes only the Swin Transformer encoder of SwinUNETR v2.

    Builds the full SwinUNETR network and keeps its ``swinViT`` submodule,
    e.g. for feature extraction without the decoder.
    """

    config_class = SwinUNETRv2Config

    def __init__(self, config):
        super().__init__(config)
        # Construct the complete model, then retain only the encoder.
        self.swinViT = SwinUNETR(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            patch_size=config.patch_size,
            depths=config.depths,
            num_heads=config.num_heads,
            window_size=config.window_size,
            qkv_bias=config.qkv_bias,
            mlp_ratio=config.mlp_ratio,
            feature_size=config.feature_size,
            norm_name=config.norm_name,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            dropout_path_rate=config.dropout_path_rate,
            normalize=config.normalize,
            patch_norm=config.patch_norm,
            use_checkpoint=config.use_checkpoint,
            spatial_dims=config.spatial_dims,
            downsample=config.downsample,
            use_v2=True,  # always the v2 variant of SwinUNETR
        ).swinViT

    def forward(self, tensor):
        """Run the encoder on *tensor* and return its feature output.

        Bug fix: the previous implementation called ``self.model(tensor)``,
        but this class never defines ``self.model`` (the encoder is stored
        as ``self.swinViT``), so every call raised ``AttributeError``.
        """
        return self.swinViT(tensor)
|
|
|
|
|
|
# Public API of this module, as re-exported by the package.
__all__ = [
    "SwinUNETRv2ForImageSegmentation",
    "SwinUNETRv2Model",
    "SwinUNETRv2PreTrainedModel",
    "SwinUNETRv2Backbone",
]
|
|
|