Burdenthrive committed on
Commit
fdd0e8b
·
verified ·
1 Parent(s): d22dfdb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +70 -0
model.py CHANGED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from typing import Optional
5
+
6
+
7
class SegFormer(nn.Module):
    """
    SegFormer model for multi-class semantic segmentation.

    Thin wrapper around ``smp.Segformer``. The default setup targets RGB
    (3 bands), but ``in_channels`` can be set to support multispectral inputs
    (e.g., 13 for Sentinel-2 L1C). ``forward`` returns raw logits with shape
    (B, num_classes, H, W); ``predict`` returns integer label maps.
    """

    def __init__(
        self,
        encoder_name: str = "mit_b4",
        encoder_weights: Optional[str] = "imagenet",  # set to None if incompatible with in_channels
        in_channels: int = 3,
        num_classes: int = 4,
        freeze_encoder: bool = False,
    ) -> None:
        """
        Build the underlying ``smp.Segformer`` and optionally freeze its encoder.

        Args:
            encoder_name: Encoder name forwarded to ``smp.Segformer``
                (e.g., 'mit_b0'...'mit_b5', default 'mit_b4').
            encoder_weights: Pretrained weights name (typically 'imagenet' or None).
            in_channels: Number of input channels (3 for RGB, 13 for Sentinel-2, etc.).
            num_classes: Number of output classes for segmentation.
            freeze_encoder: If True, disables gradients for all encoder parameters.
        """
        super().__init__()

        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            # Only gradients are disabled; the encoder's train/eval mode still
            # follows the parent module.
            self.segformer.encoder.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Logits of shape (B, num_classes, H, W).
        """
        return self.segformer(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inference helper: runs the model in eval mode and returns label maps.

        The per-pixel label is the argmax over the class dimension of the
        logits. Softmax is monotonic, so this selects the same class as
        softmax + argmax without materialising the probability tensor.

        The train/eval mode of this module and every sub-module is restored
        afterwards, so calling ``predict`` in the middle of a training loop
        does not silently leave the model in eval mode.

        Args:
            x: Tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Integer labels of shape (B, H, W).
        """
        # Snapshot per-module modes so mixed configurations (e.g., an encoder
        # deliberately kept in eval mode) survive the round trip.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(x)  # (B, num_classes, H, W)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return logits.argmax(dim=1)