vincent-doan commited on
Commit
fc75fbd
·
1 Parent(s): 9e41929

Added RCAN model

Browse files
models/RCAN/model.py DELETED
File without changes
models/RCAN/rcan.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ NUM_RESIDUAL_GROUPS = 8
4
+ NUM_RESIDUAL_BLOCKS = 16
5
+ KERNEL_SIZE = 3
6
+ REDUCTION_RATIO = 16
7
+ NUM_CHANNELS = 64
8
+ UPSCALE_FACTOR = 4
9
+
10
+ class ResidualChannelAttentionBlock(nn.Module):
11
+ def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
12
+
13
+ super(ResidualChannelAttentionBlock, self).__init__()
14
+
15
+ self.feature_extractor = nn.Sequential(
16
+ nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
17
+ nn.ReLU(),
18
+ nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
19
+ )
20
+
21
+ self.channel_attention = nn.Sequential(
22
+ nn.AdaptiveAvgPool2d(1),
23
+ nn.Conv2d(num_channels, num_channels//reduction_ratio, kernel_size=1, stride=1),
24
+ nn.ReLU(),
25
+ # nn.BatchNorm2d(num_channels//reduction_ratio),
26
+ nn.Conv2d(num_channels//reduction_ratio, num_channels, kernel_size=1, stride=1),
27
+ nn.Sigmoid()
28
+ )
29
+
30
+ def forward(self, x):
31
+ block_input = x.clone()
32
+
33
+ residual = self.feature_extractor(x) # Feature extraction
34
+ rescale = self.channel_attention(residual) # Rescaling vector
35
+
36
+ block_output = block_input + (residual * rescale)
37
+
38
+ return block_output
39
+
40
+ class ResidualGroup(nn.Module):
41
+ def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
42
+ num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
43
+
44
+ super(ResidualGroup, self).__init__()
45
+
46
+ self.residual_blocks = nn.Sequential(
47
+ *[ResidualChannelAttentionBlock(num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
48
+ for _ in range(num_residual_blocks)]
49
+ )
50
+
51
+ self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
52
+
53
+ def forward(self, x):
54
+ group_input = x.clone()
55
+
56
+ residual = self.residual_blocks(x) # Residual blocks
57
+ residual = self.final_conv(residual) # Final convolution
58
+
59
+ group_output = group_input + residual
60
+
61
+ return group_output
62
+
63
+ class ResidualInResidual(nn.Module):
64
+ def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
65
+ num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
66
+
67
+ super(ResidualInResidual, self).__init__()
68
+
69
+ self.residual_groups = nn.Sequential(
70
+ *[ResidualGroup(num_residual_blocks=num_residual_blocks,
71
+ num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
72
+ for _ in range(num_residual_groups)]
73
+ )
74
+
75
+ self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
76
+
77
+ def forward(self, x):
78
+ shallow_feature = x.clone()
79
+
80
+ residual = self.residual_groups(x) # Residual groups
81
+ residual = self.final_conv(residual) # Final convolution
82
+
83
+ deep_feature = shallow_feature + residual
84
+
85
+ return deep_feature
86
+
87
+ class RCAN(nn.Module):
88
+ def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
89
+ num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
90
+
91
+ super(RCAN, self).__init__()
92
+
93
+ self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
94
+ self.residual_in_residual = ResidualInResidual(num_residual_groups=num_residual_groups, num_residual_blocks=num_residual_blocks,
95
+ num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
96
+ self.upscaling_module = nn.PixelShuffle(upscale_factor=UPSCALE_FACTOR)
97
+ self.reconstruction_conv = nn.Conv2d(num_channels // (UPSCALE_FACTOR ** 2), 3, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
98
+
99
+ def forward(self, x):
100
+ shallow_feature = self.shallow_conv(x) # Initial convolution
101
+ deep_feature = self.residual_in_residual(shallow_feature) # Residual in Residual
102
+ upscaled_image = self.upscaling_module(deep_feature) # Upscaling module
103
+ reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
104
+
105
+ return reconstructed_image
models/RCAN/rcan_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d3235f16777e31b98266bdf9e4bae13d0ede40edde176c1ea768c54ad737e6
3
+ size 39983995