Verah committed on
Commit f028cfc · verified · 1 Parent(s): 6aa12b7

Upload denoise_util.py

Files changed (1)
  1. denoise_util.py +410 -0
denoise_util.py ADDED
@@ -0,0 +1,410 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# SOURCE: https://github.com/Ascend-Research/CascadedGaze
# ------------------------------------------------------------------------
# Modified from NAFNet (https://github.com/megvii-research/NAFNet)
# ------------------------------------------------------------------------

class LayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return (gx,
                (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
                grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
                None)

class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
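
# A quick self-check sketch (editor's illustration, not part of the original
# commit): LayerNorm2d normalizes over the channel dimension only, so the
# (N, C, H, W) shape is preserved and, with the default weight=1 / bias=0,
# each spatial location has roughly zero mean and unit variance across channels.
def _demo_layer_norm_2d():
    norm = LayerNorm2d(channels=8)
    x = torch.randn(2, 8, 16, 16)
    y = norm(x)
    assert y.shape == x.shape
    # Biased variance matches the eps-regularized normalization above.
    print(y.mean(dim=1).abs().max().item(), y.var(dim=1, unbiased=False).mean().item())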

class AvgPool2d(nn.Module):
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = F.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = F.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = F.pad(out, pad2d, mode='replicate')

        return out
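
# Sanity-check sketch (editor's illustration): with fast_imp=False the
# summed-area-table branch above is an exact stride-1 average pool, so with
# auto_pad=False it should match F.avg_pool2d up to floating-point error.
def _demo_avg_pool_2d():
    pool = AvgPool2d(kernel_size=[4, 4], auto_pad=False, fast_imp=False)
    x = torch.randn(1, 3, 32, 32)
    ref = F.avg_pool2d(x, kernel_size=4, stride=1)
    assert torch.allclose(pool(x), ref, atol=1e-5)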

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            # compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            # print(base_size)
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)

'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class Local_Base:
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)
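
# Usage sketch (editor's illustration; the model and sizes are toy
# placeholders): the TLSC conversion swaps each AdaptiveAvgPool2d(1) for the
# size-aware AvgPool2d above. Note the dummy forward pass at train_size, as in
# Local_Base.convert: it freezes kernel_size at base_size, so larger test
# images are pooled over local windows instead of globally.
def _demo_tlsc_convert():
    train_size = (1, 3, 256, 256)
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1))
    replace_layers(model, base_size=(384, 384), train_size=train_size, fast_imp=False)
    with torch.no_grad():
        model(torch.rand(train_size))            # fixes kernel_size = (384, 384)
        out = model(torch.rand(1, 3, 512, 512))  # pooled over local 384x384 windows
    print(out.shape)  # (1, 8, 512, 512): auto_pad restores the input resolution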

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout, kernel_size=3, padding=0, stride=1, bias=False):
        super(depthwise_separable_conv, self).__init__()
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, stride=stride, padding=padding, groups=nin, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


class UpsampleWithFlops(nn.Upsample):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(UpsampleWithFlops, self).__init__(size, scale_factor, mode, align_corners)
        self.__flops__ = 0

    def forward(self, input):
        self.__flops__ += input.numel()
        return super(UpsampleWithFlops, self).forward(input)


class GlobalContextExtractor(nn.Module):
    def __init__(self, c, kernel_sizes=[3, 3, 5], strides=[3, 3, 5], padding=0, bias=False):
        super(GlobalContextExtractor, self).__init__()

        self.depthwise_separable_convs = nn.ModuleList([
            depthwise_separable_conv(c, c, kernel_size, padding, stride, bias)
            for kernel_size, stride in zip(kernel_sizes, strides)
        ])

    def forward(self, x):
        outputs = []
        for conv in self.depthwise_separable_convs:
            x = F.gelu(conv(x))
            outputs.append(x)
        return outputs
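
# Shape sketch (editor's illustration): each depthwise-separable conv uses
# stride > 1 with no padding, so the extractor emits a cascade of
# progressively coarser c-channel maps, e.g. 64x64 -> 31x31 -> 10x10 below.
def _demo_gce():
    gce = GlobalContextExtractor(c=16, kernel_sizes=[3, 3], strides=[2, 3])
    feats = gce(torch.randn(1, 16, 64, 64))
    print([tuple(f.shape) for f in feats])  # [(1, 16, 31, 31), (1, 16, 10, 10)]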


class CascadedGazeBlock(nn.Module):
    def __init__(self, c, GCE_Conv=2, DW_Expand=2, FFN_Expand=2, drop_out_rate=0):
        super().__init__()
        self.dw_channel = c * DW_Expand
        self.GCE_Conv = GCE_Conv
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1,
                               padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=self.dw_channel, out_channels=self.dw_channel,
                               kernel_size=3, padding=1, stride=1, groups=self.dw_channel,
                               bias=True)

        if self.GCE_Conv == 3:
            self.GCE = GlobalContextExtractor(c=c, kernel_sizes=[3, 3, 5], strides=[2, 3, 4])

            self.project_out = nn.Conv2d(int(self.dw_channel * 2.5), c, kernel_size=1)

            self.sca = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels=int(self.dw_channel * 2.5), out_channels=int(self.dw_channel * 2.5),
                          kernel_size=1, padding=0, stride=1, groups=1, bias=True))
        else:
            self.GCE = GlobalContextExtractor(c=c, kernel_sizes=[3, 3], strides=[2, 3])

            self.project_out = nn.Conv2d(self.dw_channel * 2, c, kernel_size=1)

            self.sca = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels=self.dw_channel * 2, out_channels=self.dw_channel * 2,
                          kernel_size=1, padding=0, stride=1, groups=1, bias=True))

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp
        b, c, h, w = x.shape
        # Nearest-neighbor upsampling as part of the range fusion process
        self.upsample = UpsampleWithFlops(size=(h, w), mode='nearest')

        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.gelu(x)

        # Global Context Extractor + range fusion
        x_1, x_2 = x.chunk(2, dim=1)
        if self.GCE_Conv == 3:
            x1, x2, x3 = self.GCE(x_1 + x_2)
            x = torch.cat([x, self.upsample(x1), self.upsample(x2), self.upsample(x3)], dim=1)
        else:
            x1, x2 = self.GCE(x_1 + x_2)
            x = torch.cat([x, self.upsample(x1), self.upsample(x2)], dim=1)
        x = self.sca(x) * x
        x = self.project_out(x)

        x = self.dropout1(x)
        # channel mixing
        y = inp + x * self.beta
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)

        return y + x * self.gamma

class NAFBlock0(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        # channel mixing
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class CascadedGaze(nn.Module):

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], GCE_CONVS_nums=[]):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for i in range(len(enc_blk_nums)):
            num = enc_blk_nums[i]
            GCE_Convs = GCE_CONVS_nums[i]
            self.encoders.append(
                nn.Sequential(
                    *[CascadedGazeBlock(chan, GCE_Conv=GCE_Convs) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock0(chan) for _ in range(middle_blk_num)]
            )

        for i in range(len(dec_blk_nums)):
            num = dec_blk_nums[i]
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock0(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
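
if __name__ == "__main__":
    # Smoke-test sketch (editor's illustration): the block counts below are
    # arbitrary placeholders, not the published CascadedGaze configuration.
    net = CascadedGaze(img_channel=3, width=16, middle_blk_num=1,
                       enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1],
                       GCE_CONVS_nums=[3, 3, 2, 2])
    x = torch.randn(1, 3, 100, 100)  # H and W need not be multiples of 2**4
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # padding is cropped back: torch.Size([1, 3, 100, 100])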