import numpy as np
from core.leras import nn

tf = nn.tf


class BlurPool(nn.LayerBase):
    """Blur the input with a normalized binomial filter, then subsample.

    A 2-D kernel is built as the outer product of a 1-D row of binomial
    coefficients (e.g. ``[1, 2, 1]`` for ``filt_size=3``), normalized to sum
    to 1, and applied to every channel independently through a strided
    depthwise convolution.

    Parameters
    ----------
    filt_size : int
        Side length of the square blur kernel. Must be a positive integer.
        Sizes 1..7 yield exactly the same coefficients as the historical
        hard-coded table; larger sizes are now supported as well.
    stride : int
        Spatial stride of the convolution along both height and width.
    **kwargs
        Forwarded unchanged to ``nn.LayerBase.__init__``.

    Raises
    ------
    ValueError
        If ``filt_size`` is not an integer value >= 1.
    """

    def __init__(self, filt_size=3, stride=2, **kwargs):
        # Fail early with a clear message: previously an unsupported size left
        # the kernel variable unassigned and surfaced as an UnboundLocalError.
        if int(filt_size) != filt_size or filt_size < 1:
            raise ValueError(
                f"filt_size must be a positive integer, got {filt_size!r}"
            )
        taps = int(filt_size)

        self.filt_size = filt_size

        # Stride/padding lists are laid out to match the active data format.
        channels_last = nn.data_format == "NHWC"
        if channels_last:
            self.strides = [1, stride, stride, 1]
        else:
            self.strides = [1, 1, stride, stride]

        # Total padding per spatial axis is (taps - 1); for even sizes the
        # extra element goes on the trailing side (floor before, ceil after).
        pad = [(taps - 1) // 2, taps // 2]
        if channels_last:
            self.padding = [[0, 0], pad, pad, [0, 0]]
        else:
            self.padding = [[0, 0], [0, 0], pad, pad]

        # Row (taps - 1) of Pascal's triangle, obtained by repeatedly
        # convolving with [1, 1]: 1 -> [1, 1] -> [1, 2, 1] -> [1, 3, 3, 1] ...
        a = np.ones(1)
        for _ in range(taps - 1):
            a = np.convolve(a, [1., 1.])

        # Separable 2-D kernel, normalized so that its entries sum to 1.
        a = a[:, None] * a[None, :]
        a = a / np.sum(a)

        # Shape (taps, taps, 1, 1); tiled over the channel count in forward().
        a = a[:, :, None, None]
        self.a = a

        # Kept last, as in the original ordering, so self.a exists before the
        # base-class initializer runs.
        super().__init__(**kwargs)

    def build_weights(self):
        """Create the constant (non-trainable) blur kernel tensor."""
        self.k = tf.constant(self.a, dtype=nn.floatx)

    def forward(self, x):
        """Pad, blur and subsample ``x``.

        ``x`` is expected to be laid out according to ``nn.data_format``; its
        size along ``nn.conv2d_ch_axis`` is used to replicate the kernel.
        NOTE(review): that dimension presumably has to be statically known
        for ``tf.tile`` to accept it - confirm against callers.
        """
        # Replicate the single-channel kernel along the input-channel axis so
        # the depthwise convolution applies the same filter to every channel.
        k = tf.tile(self.k, (1, 1, x.shape[nn.conv2d_ch_axis], 1))

        # Explicit padding (tf.pad default mode) followed by a 'VALID'
        # convolution, which gives control over the asymmetric padding used
        # for even kernel sizes.
        x = tf.pad(x, self.padding)
        x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID',
                                   data_format=nn.data_format)
        return x


nn.BlurPool = BlurPool