zhengchong commited on
Commit
47e441f
1 Parent(s): fe2cfb5

chore: Update SCHP model checkpoint loading logic

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ playground.py
2
+ __pycache__
model/SCHP/__init__.py CHANGED
@@ -81,12 +81,27 @@ class SCHP:
81
 
82
 
83
  def load_ckpt(self, ckpt_path):
 
 
 
 
 
 
 
 
 
84
  state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
85
  new_state_dict = OrderedDict()
86
  for k, v in state_dict.items():
87
  name = k[7:] # remove `module.`
88
  new_state_dict[name] = v
89
- self.model.load_state_dict(new_state_dict)
 
 
 
 
 
 
90
 
91
  def _box2cs(self, box):
92
  x, y, w, h = box[:4]
@@ -148,7 +163,8 @@ class SCHP:
148
  meta_list = [meta]
149
 
150
  output = self.model(image)
151
- upsample_outputs = self.upsample(output[0][-1])
 
152
  upsample_outputs = upsample_outputs.permute(0, 2, 3, 1) # BCHW -> BHWC
153
 
154
  output_img_list = []
 
81
 
82
 
83
  def load_ckpt(self, ckpt_path):
84
+ rename_map = {
85
+ "decoder.conv3.2.weight": "decoder.conv3.3.weight",
86
+ "decoder.conv3.3.weight": "decoder.conv3.4.weight",
87
+ "decoder.conv3.3.bias": "decoder.conv3.4.bias",
88
+ "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
89
+ "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
90
+ "fushion.3.weight": "fushion.4.weight",
91
+ "fushion.3.bias": "fushion.4.bias",
92
+ }
93
  state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
94
  new_state_dict = OrderedDict()
95
  for k, v in state_dict.items():
96
  name = k[7:] # remove `module.`
97
  new_state_dict[name] = v
98
+ new_state_dict_ = OrderedDict()
99
+ for k, v in list(new_state_dict.items()):
100
+ if k in rename_map:
101
+ new_state_dict_[rename_map[k]] = v
102
+ else:
103
+ new_state_dict_[k] = v
104
+ self.model.load_state_dict(new_state_dict_, strict=False)
105
 
106
  def _box2cs(self, box):
107
  x, y, w, h = box[:4]
 
163
  meta_list = [meta]
164
 
165
  output = self.model(image)
166
+ # upsample_outputs = self.upsample(output[0][-1])
167
+ upsample_outputs = self.upsample(output)
168
  upsample_outputs = upsample_outputs.permute(0, 2, 3, 1) # BCHW -> BHWC
169
 
170
  output_img_list = []
model/SCHP/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/SCHP/__pycache__/__init__.cpython-39.pyc and b/model/SCHP/__pycache__/__init__.cpython-39.pyc differ
 
model/SCHP/networks/AugmentCE2P.py CHANGED
@@ -11,19 +11,13 @@
11
  LICENSE file in the root directory of this source tree.
12
  """
13
 
14
- import functools
15
-
16
  import torch
17
  import torch.nn as nn
18
  from torch.nn import functional as F
19
- # Note here we adopt the InplaceABNSync implementation from https://github.com/mapillary/inplace_abn
20
- # By default, the InplaceABNSync module contains a BatchNorm Layer and a LeakyReLu layer
21
- from inplace_abn import InPlaceABNSync
22
 
23
- BatchNorm2d = functools.partial(InPlaceABNSync, activation='identity')
24
 
25
  affine_par = True
26
-
27
  pretrained_settings = {
28
  'resnet101': {
29
  'imagenet': {
@@ -99,14 +93,20 @@ class PSPModule(nn.Module):
99
  self.bottleneck = nn.Sequential(
100
  nn.Conv2d(features + len(sizes) * out_features, out_features, kernel_size=3, padding=1, dilation=1,
101
  bias=False),
102
- InPlaceABNSync(out_features),
 
103
  )
104
 
105
  def _make_stage(self, features, out_features, size):
106
  prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
107
  conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
108
- bn = InPlaceABNSync(out_features)
109
- return nn.Sequential(prior, conv, bn)
 
 
 
 
 
110
 
111
  def forward(self, feats):
112
  h, w = feats.size(2), feats.size(3)
@@ -128,23 +128,35 @@ class ASPPModule(nn.Module):
128
  self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
129
  nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1,
130
  bias=False),
131
- InPlaceABNSync(inner_features))
 
 
 
132
  self.conv2 = nn.Sequential(
133
  nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False),
134
- InPlaceABNSync(inner_features))
 
 
135
  self.conv3 = nn.Sequential(
136
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
137
- InPlaceABNSync(inner_features))
 
 
138
  self.conv4 = nn.Sequential(
139
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
140
- InPlaceABNSync(inner_features))
 
 
141
  self.conv5 = nn.Sequential(
142
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
143
- InPlaceABNSync(inner_features))
 
 
144
 
145
  self.bottleneck = nn.Sequential(
146
  nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
147
- InPlaceABNSync(out_features),
 
148
  nn.Dropout2d(0.1)
149
  )
150
 
@@ -173,24 +185,27 @@ class Edge_Module(nn.Module):
173
 
174
  self.conv1 = nn.Sequential(
175
  nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
176
- InPlaceABNSync(mid_fea)
 
177
  )
178
  self.conv2 = nn.Sequential(
179
  nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
180
- InPlaceABNSync(mid_fea)
 
181
  )
182
  self.conv3 = nn.Sequential(
183
  nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
184
- InPlaceABNSync(mid_fea)
 
185
  )
186
  self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
187
- self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)
188
 
189
  def forward(self, x1, x2, x3):
190
  _, _, h, w = x1.size()
191
 
192
  edge1_fea = self.conv1(x1)
193
- edge1 = self.conv4(edge1_fea)
194
  edge2_fea = self.conv2(x2)
195
  edge2 = self.conv4(edge2_fea)
196
  edge3_fea = self.conv3(x3)
@@ -201,11 +216,12 @@ class Edge_Module(nn.Module):
201
  edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear', align_corners=True)
202
  edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear', align_corners=True)
203
 
204
- edge = torch.cat([edge1, edge2, edge3], dim=1)
205
  edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
206
- edge = self.conv5(edge)
207
 
208
- return edge, edge_fea
 
209
 
210
 
211
  class Decoder_Module(nn.Module):
@@ -217,20 +233,24 @@ class Decoder_Module(nn.Module):
217
  super(Decoder_Module, self).__init__()
218
  self.conv1 = nn.Sequential(
219
  nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
220
- InPlaceABNSync(256)
 
221
  )
222
  self.conv2 = nn.Sequential(
223
  nn.Conv2d(256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
224
- InPlaceABNSync(48)
 
225
  )
226
  self.conv3 = nn.Sequential(
227
  nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
228
- InPlaceABNSync(256),
 
229
  nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
230
- InPlaceABNSync(256)
 
231
  )
232
 
233
- self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
234
 
235
  def forward(self, xt, xl):
236
  _, _, h, w = xl.size()
@@ -238,8 +258,9 @@ class Decoder_Module(nn.Module):
238
  xl = self.conv2(xl)
239
  x = torch.cat([xt, xl], dim=1)
240
  x = self.conv3(x)
241
- seg = self.conv4(x)
242
- return seg, x
 
243
 
244
 
245
  class ResNet(nn.Module):
@@ -270,7 +291,8 @@ class ResNet(nn.Module):
270
 
271
  self.fushion = nn.Sequential(
272
  nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
273
- InPlaceABNSync(256),
 
274
  nn.Dropout2d(0.1),
275
  nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
276
  )
@@ -304,13 +326,16 @@ class ResNet(nn.Module):
304
  x4 = self.layer3(x3)
305
  x5 = self.layer4(x4)
306
  x = self.context_encoding(x5)
307
- parsing_result, parsing_fea = self.decoder(x, x2)
 
308
  # Edge Branch
309
- edge_result, edge_fea = self.edge(x2, x3, x4)
 
310
  # Fusion Branch
311
  x = torch.cat([parsing_fea, edge_fea], dim=1)
312
  fusion_result = self.fushion(x)
313
- return [[parsing_result, fusion_result], [edge_result]]
 
314
 
315
 
316
  def initialize_pretrained_model(model, settings, pretrained='./models/resnet101-imagenet.pth'):
 
11
  LICENSE file in the root directory of this source tree.
12
  """
13
 
 
 
14
  import torch
15
  import torch.nn as nn
16
  from torch.nn import functional as F
 
 
 
17
 
18
+ from torch.nn import BatchNorm2d, LeakyReLU
19
 
20
  affine_par = True
 
21
  pretrained_settings = {
22
  'resnet101': {
23
  'imagenet': {
 
93
  self.bottleneck = nn.Sequential(
94
  nn.Conv2d(features + len(sizes) * out_features, out_features, kernel_size=3, padding=1, dilation=1,
95
  bias=False),
96
+ BatchNorm2d(out_features),
97
+ LeakyReLU(),
98
  )
99
 
100
  def _make_stage(self, features, out_features, size):
101
  prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
102
  conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
103
+ return nn.Sequential(
104
+ prior,
105
+ conv,
106
+ # bn
107
+ BatchNorm2d(out_features),
108
+ LeakyReLU(),
109
+ )
110
 
111
  def forward(self, feats):
112
  h, w = feats.size(2), feats.size(3)
 
128
  self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
129
  nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1,
130
  bias=False),
131
+ # InPlaceABNSync(inner_features)
132
+ BatchNorm2d(inner_features),
133
+ LeakyReLU(),
134
+ )
135
  self.conv2 = nn.Sequential(
136
  nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False),
137
+ BatchNorm2d(inner_features),
138
+ LeakyReLU(),
139
+ )
140
  self.conv3 = nn.Sequential(
141
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
142
+ BatchNorm2d(inner_features),
143
+ LeakyReLU(),
144
+ )
145
  self.conv4 = nn.Sequential(
146
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
147
+ BatchNorm2d(inner_features),
148
+ LeakyReLU(),
149
+ )
150
  self.conv5 = nn.Sequential(
151
  nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
152
+ BatchNorm2d(inner_features),
153
+ LeakyReLU(),
154
+ )
155
 
156
  self.bottleneck = nn.Sequential(
157
  nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
158
+ BatchNorm2d(inner_features),
159
+ LeakyReLU(),
160
  nn.Dropout2d(0.1)
161
  )
162
 
 
185
 
186
  self.conv1 = nn.Sequential(
187
  nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
188
+ BatchNorm2d(mid_fea),
189
+ LeakyReLU(),
190
  )
191
  self.conv2 = nn.Sequential(
192
  nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
193
+ BatchNorm2d(mid_fea),
194
+ LeakyReLU(),
195
  )
196
  self.conv3 = nn.Sequential(
197
  nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
198
+ BatchNorm2d(mid_fea),
199
+ LeakyReLU(),
200
  )
201
  self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
202
+ # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)
203
 
204
  def forward(self, x1, x2, x3):
205
  _, _, h, w = x1.size()
206
 
207
  edge1_fea = self.conv1(x1)
208
+ # edge1 = self.conv4(edge1_fea)
209
  edge2_fea = self.conv2(x2)
210
  edge2 = self.conv4(edge2_fea)
211
  edge3_fea = self.conv3(x3)
 
216
  edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear', align_corners=True)
217
  edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear', align_corners=True)
218
 
219
+ # edge = torch.cat([edge1, edge2, edge3], dim=1)
220
  edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
221
+ # edge = self.conv5(edge)
222
 
223
+ # return edge, edge_fea
224
+ return edge_fea
225
 
226
 
227
  class Decoder_Module(nn.Module):
 
233
  super(Decoder_Module, self).__init__()
234
  self.conv1 = nn.Sequential(
235
  nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
236
+ BatchNorm2d(256),
237
+ LeakyReLU(),
238
  )
239
  self.conv2 = nn.Sequential(
240
  nn.Conv2d(256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
241
+ BatchNorm2d(48),
242
+ LeakyReLU(),
243
  )
244
  self.conv3 = nn.Sequential(
245
  nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
246
+ BatchNorm2d(256),
247
+ LeakyReLU(),
248
  nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
249
+ BatchNorm2d(256),
250
+ LeakyReLU(),
251
  )
252
 
253
+ # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
254
 
255
  def forward(self, xt, xl):
256
  _, _, h, w = xl.size()
 
258
  xl = self.conv2(xl)
259
  x = torch.cat([xt, xl], dim=1)
260
  x = self.conv3(x)
261
+ # seg = self.conv4(x)
262
+ # return seg, x
263
+ return x
264
 
265
 
266
  class ResNet(nn.Module):
 
291
 
292
  self.fushion = nn.Sequential(
293
  nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
294
+ BatchNorm2d(256),
295
+ LeakyReLU(),
296
  nn.Dropout2d(0.1),
297
  nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
298
  )
 
326
  x4 = self.layer3(x3)
327
  x5 = self.layer4(x4)
328
  x = self.context_encoding(x5)
329
+ # parsing_result, parsing_fea = self.decoder(x, x2)
330
+ parsing_fea = self.decoder(x, x2)
331
  # Edge Branch
332
+ # edge_result, edge_fea = self.edge(x2, x3, x4)
333
+ edge_fea = self.edge(x2, x3, x4)
334
  # Fusion Branch
335
  x = torch.cat([parsing_fea, edge_fea], dim=1)
336
  fusion_result = self.fushion(x)
337
+ # return [[parsing_result, fusion_result], [edge_result]]
338
+ return fusion_result
339
 
340
 
341
  def initialize_pretrained_model(model, settings, pretrained='./models/resnet101-imagenet.pth'):
model/SCHP/networks/__pycache__/AugmentCE2P.cpython-39.pyc CHANGED
Binary files a/model/SCHP/networks/__pycache__/AugmentCE2P.cpython-39.pyc and b/model/SCHP/networks/__pycache__/AugmentCE2P.cpython-39.pyc differ
 
model/SCHP/networks/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/model/SCHP/networks/__pycache__/__init__.cpython-39.pyc and b/model/SCHP/networks/__pycache__/__init__.cpython-39.pyc differ
 
model/SCHP/utils/__pycache__/transforms.cpython-39.pyc CHANGED
Binary files a/model/SCHP/utils/__pycache__/transforms.cpython-39.pyc and b/model/SCHP/utils/__pycache__/transforms.cpython-39.pyc differ