File size: 5,732 Bytes
fcd5579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from core.leras import nn
tf = nn.tf

class XSeg(nn.ModelBase):
    
    def on_build (self, in_ch, base_ch, out_ch):
        
        class ConvBlock(nn.ModelBase):
            def on_build(self, in_ch, out_ch):              
                self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):                
                x = self.conv(x)
                x = self.frn(x)
                x = self.tlu(x)
                return x

        class UpConvBlock(nn.ModelBase):
            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                x = self.conv(x)
                x = self.frn(x)
                x = self.tlu(x)
                return x

        self.conv01 = ConvBlock(in_ch, base_ch)
        self.conv02 = ConvBlock(base_ch, base_ch)
        self.bp0 = nn.BlurPool (filt_size=3)


        self.conv11 = ConvBlock(base_ch, base_ch*2)
        self.conv12 = ConvBlock(base_ch*2, base_ch*2)
        self.bp1 = nn.BlurPool (filt_size=3)

        self.conv21 = ConvBlock(base_ch*2, base_ch*4)
        self.conv22 = ConvBlock(base_ch*4, base_ch*4)
        self.conv23 = ConvBlock(base_ch*4, base_ch*4)
        self.bp2 = nn.BlurPool (filt_size=3)


        self.conv31 = ConvBlock(base_ch*4, base_ch*8)
        self.conv32 = ConvBlock(base_ch*8, base_ch*8)
        self.conv33 = ConvBlock(base_ch*8, base_ch*8)
        self.bp3 = nn.BlurPool (filt_size=3)

        self.conv41 = ConvBlock(base_ch*8, base_ch*8)
        self.conv42 = ConvBlock(base_ch*8, base_ch*8)
        self.conv43 = ConvBlock(base_ch*8, base_ch*8)
        self.bp4 = nn.BlurPool (filt_size=3)
        
        self.up4 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv42 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv41 = ConvBlock(base_ch*8, base_ch*8)

        self.up3 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv33 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv32 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv31 = ConvBlock(base_ch*8, base_ch*8)

        self.up2 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv23 = ConvBlock(base_ch*8, base_ch*4)
        self.uconv22 = ConvBlock(base_ch*4, base_ch*4)
        self.uconv21 = ConvBlock(base_ch*4, base_ch*4)

        self.up1 = UpConvBlock (base_ch*4, base_ch*2)
        self.uconv12 = ConvBlock(base_ch*4, base_ch*2)
        self.uconv11 = ConvBlock(base_ch*2, base_ch*2)

        self.up0 = UpConvBlock (base_ch*2, base_ch)
        self.uconv02 = ConvBlock(base_ch*2, base_ch)
        self.uconv01 = ConvBlock(base_ch, base_ch)
        
        self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
        
        self.conv_center = ConvBlock(base_ch*8, base_ch*8)
        
        #self.ae_latent_enc = nn.Dense( base_ch*8, 64 )
        #self.ae_latent_dec = nn.Dense( 64, base_ch*8 )
        
        #self.ae_up4 = nn.Conv2D( base_ch*8, base_ch*8 *4, kernel_size=3, padding='SAME')
        #self.ae_up3 = nn.Conv2D( base_ch*8, base_ch*8 *4, kernel_size=3, padding='SAME')
        #self.ae_up2 = nn.Conv2D( base_ch*8, base_ch*4 *4, kernel_size=3, padding='SAME')
        #self.ae_up1 = nn.Conv2D( base_ch*4, base_ch*2 *4, kernel_size=3, padding='SAME')
        #self.ae_up0 = nn.Conv2D( base_ch*2, base_ch   *4, kernel_size=3, padding='SAME')
        
        
        
    def forward(self, inp):
        x = inp

        x = self.conv01(x)
        x = x0 = self.conv02(x)
        x = self.bp0(x)

        x = self.conv11(x)
        x = x1 = self.conv12(x)
        x = self.bp1(x)

        x = self.conv21(x)
        x = self.conv22(x)
        x = x2 = self.conv23(x)
        x = self.bp2(x)

        x = self.conv31(x)
        x = self.conv32(x)
        x = x3 = self.conv33(x)
        x = self.bp3(x)

        x = self.conv41(x)
        x = self.conv42(x)
        x = x4 = self.conv43(x)
        x = self.bp4(x)

        ae_x = x = self.conv_center(x)
        
       
        
        x = self.up4(x)
        x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
        x = self.uconv42(x)
        x = self.uconv41(x)

        x = self.up3(x)
        x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
        x = self.uconv32(x)
        x = self.uconv31(x)

        x = self.up2(x)
        x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
        x = self.uconv22(x)
        x = self.uconv21(x)

        x = self.up1(x)
        x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
        x = self.uconv11(x)

        x = self.up0(x)
        x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
        x = self.uconv01(x)
        
        """
        ae_x = nn.flatten(x)
        ae_x = self.ae_latent_enc(ae_x)
        ae_x = self.ae_latent_dec(ae_x)
        ae_x = nn.reshape_4D (ae_x, 8, 8, 64)
        
        ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up4(ae_x), 0.1), 2)
        ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up3(ae_x), 0.1), 2)
        ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up2(ae_x), 0.1), 2)
        ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up1(ae_x), 0.1), 2)
        ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up0(ae_x), 0.1), 2)
        
        x = tf.concat([x,ae_x],axis=nn.conv2d_ch_axis)
        """
        
        logits = self.out_conv(x)
        return logits, tf.nn.sigmoid(logits)

nn.XSeg = XSeg