|
import torch |
|
import torch.nn as nn |
|
import segmentation_models_pytorch as smp |
|
from typing import Optional |
|
|
|
|
|
class SegFormer(nn.Module):
    """SegFormer model for multi-class semantic segmentation.

    Thin wrapper around ``segmentation_models_pytorch.Segformer``. The default
    setup targets RGB (3 bands), but ``in_channels`` can be set to support
    multispectral inputs (e.g., 13 for Sentinel-2 L1C). ``forward`` returns raw
    logits with shape (B, num_classes, H, W).
    """

    def __init__(
        self,
        encoder_name: str = "mit_b4",
        encoder_weights: Optional[str] = "imagenet",
        in_channels: int = 3,
        num_classes: int = 4,
        freeze_encoder: bool = False,
    ) -> None:
        """Build the wrapped ``smp.Segformer`` and optionally freeze its encoder.

        Args:
            encoder_name: Encoder backbone name forwarded to ``smp.Segformer``
                (e.g., 'mit_b0'...'mit_b5', default 'mit_b4').
            encoder_weights: Pretrained weights name (typically 'imagenet'),
                or None for no pretrained weights.
            in_channels: Number of input channels (3 for RGB, 13 for
                Sentinel-2, etc.).
            num_classes: Number of output classes for segmentation.
            freeze_encoder: If True, disables gradients for every encoder
                parameter so they are not updated during training. This does
                not change the encoder's train/eval mode.
        """
        super().__init__()

        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            # Module-level helper; equivalent to setting requires_grad=False
            # on each encoder parameter individually.
            self.segformer.encoder.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its raw output.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            Logits of shape (B, num_classes, H, W).
        """
        return self.segformer(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Inference helper: return the most likely class index per pixel.

        Switches the module to eval mode (it is NOT switched back afterwards,
        so call ``.train()`` before resuming training) and runs without
        gradient tracking.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            Integer labels of shape (B, H, W).
        """
        self.eval()
        # Call the module instead of ``forward`` so that any registered
        # forward hooks / pre-hooks are executed as well.
        logits = self(x)
        # Softmax is monotonic along the class dimension, so the argmax of the
        # raw logits picks the same class without allocating a full
        # (B, num_classes, H, W) probability tensor.
        return logits.argmax(dim=1)