MRiabov committed
Commit a2999cc · 1 Parent(s): e3ab023

Add ResNet-50 as a backbone option

Files changed (2)
  1. README.md +1 -1
  2. src/wireseghr/model/encoder.py +65 -49
README.md CHANGED
@@ -32,7 +32,7 @@ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/ima
 
 ### Backbone Source
 - HuggingFace Transformers SegFormer (e.g., `nvidia/mit-b3`). We set `num_channels` to match input channels.
-- Fallback: a small internal CNN that preserves 1/4, 1/8, 1/16, 1/32 strides with channels [64, 128, 320, 512].
+- Alternative: TorchVision ResNet-50 (`backbone: resnet50`). The stem is adapted to the requested `in_channels`, and we expose features from `layer1`..`layer4` at strides 1/4, 1/8, 1/16, 1/32 with channels [256, 512, 1024, 2048].
 
 ## Dataset Convention
 - Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
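
For context, selecting the new backbone from Python might look like the sketch below. This is a hypothetical usage example, not part of the commit: the constructor keywords (`in_channels`, `backbone`, `pretrained`) and the import path are inferred from the encoder diff that follows.

```python
# Hypothetical usage sketch; assumes src/ is on PYTHONPATH and that the
# SegFormerEncoder constructor accepts these keywords (inferred from the diff).
from wireseghr.model.encoder import SegFormerEncoder

# ResNet-50 backbone with a non-RGB input (e.g., RGB plus one extra plane)
enc = SegFormerEncoder(in_channels=4, backbone="resnet50", pretrained=True)
print(enc.feature_dims)  # [256, 512, 1024, 2048]
```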
src/wireseghr/model/encoder.py CHANGED
@@ -1,14 +1,18 @@
-"""SegFormer MiT encoder wrapper with adjustable input channels.
+"""Encoder wrappers with adjustable input channels.
 
-Uses HuggingFace Transformers SegFormer (e.g., mit_b2) and returns a list of
-multi-scale features [C1, C2, C3, C4]. Falls back to a tiny CNN if HF isn't
-available.
+Supports two backbone families:
+- HuggingFace Transformers SegFormer (e.g., "mit_b2")
+- TorchVision ResNet-50 (use backbone "resnet50" | "resnet-50" | "resnet_50")
+
+Both return a list of 4 multi-scale feature maps [C1, C2, C3, C4] at strides
+1/4, 1/8, 1/16, 1/32 respectively.
 """
 
 from typing import List, Tuple
 
 import torch
 import torch.nn as nn
+from torchvision.models import resnet50, ResNet50_Weights
 
 
 class SegFormerEncoder(nn.Module):
@@ -23,62 +27,74 @@ class SegFormerEncoder(nn.Module):
         self.in_channels = in_channels
         self.pretrained = pretrained
 
-        # Prefer HuggingFace SegFormer for 'mit_*' backbones.
-        # Fallback to Tiny CNN if HF unavailable or unsupported.
         self.hf = None
-        prefer_hf = backbone.startswith("mit_") or backbone.startswith("segformer")
-        if prefer_hf:
-            # HF -> tiny
-            try:
-                self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
-                self.feature_dims = self.hf.feature_dims
-            except Exception:
-                self.hf = None
-                self.fallback = _TinyEncoder(in_channels)
-                self.feature_dims = [64, 128, 320, 512]
+        self.resnet = None
+
+        # SegFormer path
+        if backbone.startswith("mit_") or backbone.startswith("segformer"):
+            self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
+            self.feature_dims = self.hf.feature_dims
+        # ResNet-50 path
+        elif backbone in ("resnet50", "resnet-50", "resnet_50"):
+            self.resnet = _ResNetEncoderWrapper(in_channels, pretrained)
+            self.feature_dims = self.resnet.feature_dims
         else:
-            # tiny
-            self.fallback = _TinyEncoder(in_channels)
-            self.feature_dims = [64, 128, 320, 512]
+            raise ValueError(
+                f"Unsupported backbone '{backbone}'. Use one of: mit_b[0-5], segformer*, resnet50."
+            )
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         if self.hf is not None:
             return self.hf(x)
-        else:
-            return self.fallback(x)
+        if self.resnet is not None:
+            return self.resnet(x)
+        raise AssertionError("No encoder instantiated")
 
 
-class _TinyEncoder(nn.Module):
-    def __init__(self, in_chans: int):
+class _ResNetEncoderWrapper(nn.Module):
+    def __init__(self, in_chans: int, pretrained: bool):
         super().__init__()
-        # Output strides: 4, 8, 16, 32 with channels 64,128,320,512
-        self.stem = nn.Sequential(
-            nn.Conv2d(in_chans, 64, kernel_size=7, stride=4, padding=3, bias=False),
-            nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
-        )
-        self.stage1 = nn.Sequential(
-            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
-        )
-        self.stage2 = nn.Sequential(
-            nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1, bias=False),
-            nn.BatchNorm2d(320),
-            nn.ReLU(inplace=True),
-        )
-        self.stage3 = nn.Sequential(
-            nn.Conv2d(320, 512, kernel_size=3, stride=2, padding=1, bias=False),
-            nn.BatchNorm2d(512),
-            nn.ReLU(inplace=True),
-        )
+        # Build base ResNet-50
+        if pretrained:
+            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
+        else:
+            self.model = resnet50(weights=None)
+
+        # Adjust input stem for arbitrary channel count
+        if in_chans != 3:
+            old_conv = self.model.conv1
+            new_conv = nn.Conv2d(
+                in_chans, old_conv.out_channels, kernel_size=old_conv.kernel_size[0],
+                stride=old_conv.stride[0], padding=old_conv.padding[0], bias=False
+            )
+            with torch.no_grad():
+                if pretrained and old_conv.weight.shape[1] == 3:
+                    w = old_conv.weight  # [64, 3, 7, 7]
+                    if in_chans > 3:
+                        w_mean = w.mean(dim=1, keepdim=True)
+                        new_w = w_mean.repeat(1, in_chans, 1, 1)
+                    else:
+                        new_w = w[:, :in_chans, :, :]
+                    new_conv.weight.copy_(new_w)
+                else:
+                    nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")
+            self.model.conv1 = new_conv
+
+        self.feature_dims = [256, 512, 1024, 2048]
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        c0 = self.stem(x)  # 1/4
-        c1 = self.stage1(c0)  # 1/8
-        c2 = self.stage2(c1)  # 1/16
-        c3 = self.stage3(c2)  # 1/32
-        return [c0, c1, c2, c3]
+        # Stem
+        x = self.model.conv1(x)
+        x = self.model.bn1(x)
+        x = self.model.relu(x)
+        x = self.model.maxpool(x)  # 1/4
+
+        # Stages
+        c1 = self.model.layer1(x)   # 1/4, 256
+        c2 = self.model.layer2(c1)  # 1/8, 512
+        c3 = self.model.layer3(c2)  # 1/16, 1024
+        c4 = self.model.layer4(c3)  # 1/32, 2048
+        return [c1, c2, c3, c4]
 
 
 class _HFEncoderWrapper(nn.Module):
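
The stem adaptation reuses pretrained RGB kernels when `in_chans > 3` by averaging them and replicating the mean across the new channel count. A standalone illustration of that rule, with random weights standing in for the pretrained ones:

```python
# Illustration of the weight-inflation rule from _ResNetEncoderWrapper.
import torch

w = torch.randn(64, 3, 7, 7)               # stand-in for pretrained conv1 weights
in_chans = 5
w_mean = w.mean(dim=1, keepdim=True)       # [64, 1, 7, 7]
new_w = w_mean.repeat(1, in_chans, 1, 1)   # [64, 5, 7, 7]
assert new_w.shape == (64, in_chans, 7, 7)
```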
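A quick way to sanity-check the new path is to run a dummy tensor through the encoder and confirm the advertised strides and channel counts. A minimal sketch, assuming the same inferred constructor signature as in the earlier usage example:

```python
# Smoke test for the ResNet-50 path (hypothetical; not part of the commit).
import torch
from wireseghr.model.encoder import SegFormerEncoder

enc = SegFormerEncoder(in_channels=5, backbone="resnet50", pretrained=False)
x = torch.randn(2, 5, 512, 512)
c1, c2, c3, c4 = enc(x)
assert c1.shape[1:] == (256, 128, 128)   # 1/4
assert c2.shape[1:] == (512, 64, 64)     # 1/8
assert c3.shape[1:] == (1024, 32, 32)    # 1/16
assert c4.shape[1:] == (2048, 16, 16)    # 1/32
```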