import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from typing import Optional


class SegFormer(nn.Module):
    """
    SegFormer model for multi-class semantic segmentation.

    Thin wrapper around ``smp.Segformer``. The default setup targets RGB
    (3 bands), but ``in_channels`` can be set to support multispectral
    inputs (e.g., 13 for Sentinel-2 L1C).

    Outputs raw logits with shape (B, num_classes, H, W).
    """

    def __init__(
        self,
        encoder_name: str = "mit_b4",
        encoder_weights: Optional[str] = "imagenet",  # set to None if incompatible with in_channels
        in_channels: int = 3,
        num_classes: int = 4,
        freeze_encoder: bool = False,
    ) -> None:
        """
        Build the underlying ``smp.Segformer`` network.

        Args:
            encoder_name: Encoder name accepted by ``smp.Segformer``
                (e.g., 'mit_b0'...'mit_b5', default 'mit_b4').
            encoder_weights: Pretrained weights name (typically 'imagenet'
                or None).
            in_channels: Number of input channels (3 for RGB, 13 for
                Sentinel-2, etc.).
            num_classes: Number of output classes for segmentation.
            freeze_encoder: If True, sets ``requires_grad = False`` on every
                encoder parameter so only the decoder/head is trained.

        Raises:
            ValueError: If ``in_channels`` or ``num_classes`` is less than 1.
        """
        super().__init__()

        # Fail fast with a clear message instead of building a degenerate net.
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            # Only gradients are disabled; the encoder still runs in forward.
            for p in self.segformer.encoder.parameters():
                p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W).
                NOTE(review): smp models may require H and W to be divisible
                by the encoder's output stride (typically 32) - confirm for
                the installed smp version.

        Returns:
            torch.Tensor: Logits of shape (B, num_classes, H, W).
        """
        return self.segformer(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: return the per-pixel argmax class label.

        The module is switched to eval mode for the forward pass. Afterwards,
        every submodule's previous training/eval flag is restored, so calling
        this from inside a training loop does not silently leave the model
        in eval mode. Gradients are disabled via ``torch.no_grad``.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Integer (int64) labels of shape (B, H, W).
        """
        # Snapshot per-module flags so mixed modes (e.g. an encoder the user
        # deliberately keeps in eval) are restored exactly, not just the root.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(x)  # (B, num_classes, H, W)
        finally:
            # Restore the caller's modes even if the forward pass raises.
            for module, mode in previous_modes:
                module.training = mode

        # Softmax is monotonic along the class dimension, so argmax over raw
        # logits yields the same labels without the extra exp/normalize pass.
        return logits.argmax(dim=1)