Burdenthrive's picture
Update model.py
fdd0e8b verified
raw
history blame
2.27 kB
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from typing import Optional
class SegFormer(nn.Module):
    """
    SegFormer model for multi-class semantic segmentation.

    Thin wrapper around ``smp.Segformer``. The default setup targets RGB
    (3 bands), but you can set `in_channels` to support multispectral inputs
    (e.g., 13 for Sentinel-2 L1C). ``forward`` outputs raw logits with shape
    (B, num_classes, H, W); ``predict`` converts them to integer label maps.
    """

    def __init__(
        self,
        encoder_name: str = "mit_b4",
        encoder_weights: Optional[str] = "imagenet",  # set to None if incompatible with in_channels
        in_channels: int = 3,
        num_classes: int = 4,
        freeze_encoder: bool = False,
    ) -> None:
        """
        Build the wrapped network and optionally freeze its encoder.

        Args:
            encoder_name: Encoder name forwarded to ``smp.Segformer``
                (e.g., 'mit_b0'...'mit_b5', default 'mit_b4').
            encoder_weights: Pretrained weights name (typically 'imagenet' or None).
            in_channels: Number of input channels (3 for RGB, 13 for Sentinel-2, etc.).
            num_classes: Number of output classes for segmentation.
            freeze_encoder: If True, disables gradients for all encoder
                parameters so only the remaining parameters are trained.
        """
        super().__init__()
        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )
        if freeze_encoder:
            # Only gradient tracking is switched off; the encoder's
            # train/eval mode still follows the parent module.
            self.segformer.encoder.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Logits of shape (B, num_classes, H, W).
        """
        return self.segformer(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: takes the per-pixel argmax over the class logits.

        Note:
            Switches the module to eval mode and leaves it there; call
            ``train()`` again before resuming training.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Integer labels of shape (B, H, W).
        """
        self.eval()
        logits = self.forward(x)  # (B, num_classes, H, W)
        # Softmax is monotonic along the class axis, so it cannot change the
        # winning class. Skipping it avoids an extra full-size tensor and the
        # NaN/tie artefacts softmax produces for extreme logits.
        return logits.argmax(dim=1)