# SankarSrin's picture
# Duplicate from vivym/image-matting-app
# 36239b8
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import time
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager
from ppmatting.models.losses import MRSD
def conv_up_psp(in_channels, out_channels, up_sample):
    """Build a 3x3 ConvBNReLU block followed by bilinear upsampling.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        up_sample: Scale factor handed to ``nn.Upsample``.

    Returns:
        An ``nn.Sequential`` containing the convolution block and then the
        upsampling layer.
    """
    conv_block = layers.ConvBNReLU(in_channels, out_channels, 3, padding=1)
    upsampler = nn.Upsample(
        scale_factor=up_sample, mode='bilinear', align_corners=False)
    return nn.Sequential(conv_block, upsampler)
@manager.MODELS.add_component
class HumanMatting(nn.Layer):
    """Matting network with a glance decoder, a focus decoder and an optional refiner.

    The glance decoder produces a 3-channel map that is soft-maxed over the
    channel axis (channel 0: foreground, 1: transition region, 2: background).
    The focus decoder produces 34 channels: one alpha logit, one error map and
    32 hidden channels. ``fusion`` merges the glance and focus predictions
    into a coarse alpha, which is refined by ``Refiner`` when ``if_refine`` is
    set.

    Args:
        backbone: Feature extractor. It must expose ``feat_channels`` and,
            when called, return a sequence of feature maps of which the last
            five entries are consumed.
        pretrained: Forwarded to ``utils.load_entire_model`` when it is not
            None. Default: None.
        backbone_scale: Scale factor applied to the input image before it is
            fed to the backbone. It must not exceed 0.5 when ``if_refine`` is
            True and is overridden with 1 when ``if_refine`` is False.
            Default: 0.25.
        refine_kernel_size: Kernel size handed to ``Refiner``. Default: 3.
        if_refine: Whether the refiner is built and run. Default: True.

    Raises:
        ValueError: If ``if_refine`` is True and ``backbone_scale`` is
            greater than 0.5.
    """

    def __init__(self,
                 backbone,
                 pretrained=None,
                 backbone_scale=0.25,
                 refine_kernel_size=3,
                 if_refine=True):
        super().__init__()
        if not if_refine:
            # Without a refiner the backbone input is not downscaled.
            backbone_scale = 1
        elif backbone_scale > 0.5:
            raise ValueError(
                'Backbone_scale should not be greater than 1/2, but it is {}'
                .format(backbone_scale))

        self.backbone = backbone
        self.backbone_scale = backbone_scale
        self.pretrained = pretrained
        self.if_refine = if_refine
        if if_refine:
            self.refiner = Refiner(kernel_size=refine_kernel_size)
        self.loss_func_dict = None
        self.backbone_channels = backbone.feat_channels

        skip_channels = self.backbone_channels
        deepest = skip_channels[-1]

        def conv3x3(cin, cout, dilation=1):
            # For a 3x3 kernel, padding == dilation keeps the spatial size.
            if dilation == 1:
                return layers.ConvBNReLU(cin, cout, 3, padding=1)
            return layers.ConvBNReLU(
                cin, cout, 3, padding=dilation, dilation=dilation)

        def upsample2x():
            return nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False)

        ######################
        ### Decoder part - GLANCE
        ######################
        self.psp_module = layers.PPModule(
            deepest,
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)
        # Pyramid-pooling output upsampled to the resolution of each glance
        # stage it is concatenated with in forward().
        self.psp4 = conv_up_psp(512, 256, 2)
        self.psp3 = conv_up_psp(512, 128, 4)
        self.psp2 = conv_up_psp(512, 64, 8)
        self.psp1 = conv_up_psp(512, 64, 16)
        # stage 5g: input is psp output (512) + deepest backbone feature
        self.decoder5_g = nn.Sequential(
            conv3x3(512 + deepest, 512),
            conv3x3(512, 512, dilation=2),
            conv3x3(512, 256, dilation=2),
            upsample2x())
        # stage 4g: input is psp4 (256) + stage 5g (256)
        self.decoder4_g = nn.Sequential(
            conv3x3(512, 256),
            conv3x3(256, 256),
            conv3x3(256, 128),
            upsample2x())
        # stage 3g: input is psp3 (128) + stage 4g (128)
        self.decoder3_g = nn.Sequential(
            conv3x3(256, 128),
            conv3x3(128, 128),
            conv3x3(128, 64),
            upsample2x())
        # stage 2g: input is psp2 (64) + stage 3g (64)
        self.decoder2_g = nn.Sequential(
            conv3x3(128, 128),
            conv3x3(128, 128),
            conv3x3(128, 64),
            upsample2x())
        # stage 1g: input is psp1 (64) + stage 2g (64)
        self.decoder1_g = nn.Sequential(
            conv3x3(128, 64),
            conv3x3(64, 64),
            conv3x3(64, 64),
            upsample2x())
        # stage 0g: 3 output channels, soft-maxed in forward()
        self.decoder0_g = nn.Sequential(
            conv3x3(64, 64),
            conv3x3(64, 64),
            nn.Conv2D(
                64, 3, 3, padding=1))

        ##########################
        ### Decoder part - FOCUS
        ##########################
        self.bridge_block = nn.Sequential(
            conv3x3(deepest, 512, dilation=2),
            conv3x3(512, 512, dilation=2),
            conv3x3(512, 512, dilation=2))
        # stage 5f: input is bridge output (512) + deepest backbone feature
        self.decoder5_f = nn.Sequential(
            conv3x3(512 + deepest, 512),
            conv3x3(512, 512, dilation=2),
            conv3x3(512, 256, dilation=2),
            upsample2x())
        # stage 4f: previous stage (256) + backbone feature [-2]
        self.decoder4_f = nn.Sequential(
            conv3x3(256 + skip_channels[-2], 256),
            conv3x3(256, 256),
            conv3x3(256, 128),
            upsample2x())
        # stage 3f: previous stage (128) + backbone feature [-3]
        self.decoder3_f = nn.Sequential(
            conv3x3(128 + skip_channels[-3], 128),
            conv3x3(128, 128),
            conv3x3(128, 64),
            upsample2x())
        # stage 2f: previous stage (64) + backbone feature [-4]
        self.decoder2_f = nn.Sequential(
            conv3x3(64 + skip_channels[-4], 128),
            conv3x3(128, 128),
            conv3x3(128, 64),
            upsample2x())
        # stage 1f: previous stage (64) + backbone feature [-5]
        self.decoder1_f = nn.Sequential(
            conv3x3(64 + skip_channels[-5], 64),
            conv3x3(64, 64),
            conv3x3(64, 64),
            upsample2x())
        # stage 0f: 1 alpha logit + 1 error channel + 32 hidden channels,
        # sliced apart in forward()
        self.decoder0_f = nn.Sequential(
            conv3x3(64, 64),
            conv3x3(64, 64),
            nn.Conv2D(
                64, 1 + 1 + 32, 3, padding=1))

        self.init_weight()

    def forward(self, data):
        """Run the network on ``data['img']``.

        Args:
            data: Dict holding the input image under ``'img'``. In training
                mode the same dict is handed to ``loss`` as the label dict,
                which additionally reads ``'trimap'`` and ``'alpha'``.

        Returns:
            In training mode, a tuple ``(logit_dict, loss_dict)``. Otherwise
            the refined alpha when ``if_refine`` is True, else the fused
            coarse alpha.

        Raises:
            ValueError: If refinement is enabled and the image height or
                width is not divisible by 4.
        """
        image = data['img']
        img_h, img_w = paddle.shape(image)[2:]
        # The size check only runs when the sizes are paddle Tensors; the
        # original note says it is not needed when exporting.
        if self.if_refine and isinstance(img_h, paddle.Tensor):
            if (img_h % 4 != 0) or (img_w % 4 != 0):
                raise ValueError(
                    'The input image must have width and height that are divisible by 4'
                )

        # Rescale the image for the backbone.
        image_small = F.interpolate(
            image,
            scale_factor=self.backbone_scale,
            mode='bilinear',
            align_corners=False)

        feats = self.backbone(image_small)
        deepest = feats[-1]

        ##########################
        ### Decoder part - GLANCE
        ##########################
        # Channel counts below follow the layer definitions in __init__;
        # every stage 5..1 doubles the spatial size of its input.
        context = self.psp_module(deepest)  # 512 channels
        glance = self.decoder5_g(paddle.concat((context, deepest), 1))  # 256
        glance = self.decoder4_g(
            paddle.concat((self.psp4(context), glance), 1))  # 128
        glance = self.decoder3_g(
            paddle.concat((self.psp3(context), glance), 1))  # 64
        glance = self.decoder2_g(
            paddle.concat((self.psp2(context), glance), 1))  # 64
        glance = self.decoder1_g(
            paddle.concat((self.psp1(context), glance), 1))  # 64
        glance_logits = self.decoder0_g(glance)  # 3
        # The 1st channel is foreground. The 2nd is transition region.
        # The 3rd is background.
        glance_prob = F.softmax(glance_logits, axis=1)

        ##########################
        ### Decoder part - FOCUS
        ##########################
        bridge = self.bridge_block(deepest)  # 512 channels
        focus = self.decoder5_f(paddle.concat((bridge, deepest), 1))  # 256
        focus = self.decoder4_f(paddle.concat((focus, feats[-2]), 1))  # 128
        focus = self.decoder3_f(paddle.concat((focus, feats[-3]), 1))  # 64
        focus = self.decoder2_f(paddle.concat((focus, feats[-4]), 1))  # 64
        focus = self.decoder1_f(paddle.concat((focus, feats[-5]), 1))  # 64
        focus_out = self.decoder0_f(focus)  # 34

        # Channel 0: alpha logit, channel 1: error map, channels 2+: hidden.
        focus_prob = F.sigmoid(focus_out[:, 0:1, :, :])
        pha_sm = self.fusion(glance_prob, focus_prob)
        err_sm = paddle.clip(focus_out[:, 1:2, :, :], 0., 1.)
        hid_sm = F.relu(focus_out[:, 2:, :, :])

        # Refiner
        if self.if_refine:
            pha = self.refiner(
                src=image, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_prob)
            # Clamp outputs
            pha = paddle.clip(pha, 0., 1.)

        if not self.training:
            return pha if self.if_refine else pha_sm

        logit_dict = {
            'glance': glance_prob,
            'focus': focus_prob,
            'fusion': pha_sm,
            'error': err_sm
        }
        if self.if_refine:
            logit_dict['refine'] = pha
        loss_dict = self.loss(logit_dict, data)
        return logit_dict, loss_dict

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the training losses.

        Args:
            logit_dict: Dict with the predictions ``'glance'``, ``'focus'``,
                ``'fusion'`` and ``'error'``, plus ``'refine'`` when
                ``if_refine`` is True.
            label_dict: Dict with the ground truth ``'trimap'`` and
                ``'alpha'``. In the trimap, 128 marks the transition region
                and 0 the background; every other value is treated as
                foreground.
            loss_func_dict: Optional mapping from ``'glance'``, ``'focus'``,
                ``'cm'``, ``'err'`` and ``'refine'`` to lists whose first
                element is the loss callable. When given, it replaces the
                stored mapping; when None, defaults are created on first use.

        Returns:
            Dict with the individual losses under ``'glance'``, ``'focus'``,
            ``'cm'``, ``'err'`` (and ``'refine'`` when refining) and their
            weighted sum under ``'all'``.
        """
        if loss_func_dict is not None:
            self.loss_func_dict = loss_func_dict
        elif self.loss_func_dict is None:
            defaults = self.loss_func_dict = defaultdict(list)
            defaults['glance'].append(nn.NLLLoss())
            defaults['focus'].append(MRSD())
            defaults['cm'].append(MRSD())
            defaults['err'].append(paddleseg.models.MSELoss())
            defaults['refine'].append(paddleseg.models.L1Loss())
        loss_funcs = self.loss_func_dict

        loss = {}
        glance_pred = logit_dict['glance']
        focus_pred = logit_dict['focus']
        fusion_pred = logit_dict['fusion']
        alpha_gt = label_dict['alpha']

        # glance loss: turn the trimap into class indices
        # (0: foreground, 1: transition, 2: background)
        trimap = F.interpolate(
            label_dict['trimap'],
            glance_pred.shape[2:],
            mode='nearest',
            align_corners=False)
        transition = (trimap == 128).astype('int64')
        background = (trimap == 0).astype('int64')
        glance_label = transition + background * 2
        loss_glance = loss_funcs['glance'][0](
            paddle.log(glance_pred + 1e-6), glance_label.squeeze(1))
        loss['glance'] = loss_glance

        # focus loss: alpha error, with the transition mask as third argument
        focus_label = F.interpolate(
            alpha_gt,
            focus_pred.shape[2:],
            mode='bilinear',
            align_corners=False)
        loss_focus = loss_funcs['focus'][0](focus_pred, focus_label,
                                            transition)
        loss['focus'] = loss_focus

        # collaborative matting loss on the fused prediction
        loss_cm = loss_funcs['cm'][0](fusion_pred, focus_label)
        loss['cm'] = loss_cm

        # error loss: predicted error map vs. actual |fusion - alpha|,
        # both compared at the resolution of the ground-truth alpha
        full_size = alpha_gt.shape[2:]
        err = F.interpolate(
            logit_dict['error'],
            full_size,
            mode='bilinear',
            align_corners=False)
        fusion_full = F.interpolate(
            fusion_pred, full_size, mode='bilinear', align_corners=False)
        err_label = (fusion_full - alpha_gt).abs()
        loss_err = loss_funcs['err'][0](err, err_label)
        loss['err'] = loss_err

        loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm + loss_err

        # refine loss
        if self.if_refine:
            loss_refine = loss_funcs['refine'][0](logit_dict['refine'],
                                                  alpha_gt)
            loss['refine'] = loss_refine
            loss_all = loss_all + loss_refine

        loss['all'] = loss_all
        return loss

    def fusion(self, glance_sigmoid, focus_sigmoid):
        """Merge the glance and focus predictions into one alpha map.

        Args:
            glance_sigmoid: [N, 3, H, W] class scores; along axis 1, index 0
                is foreground, 1 is transition and 2 is background.
            focus_sigmoid: Alpha prediction of the focus decoder.

        Returns:
            A map that is 1 where foreground wins, 0 where background wins
            and equal to ``focus_sigmoid`` where the transition class wins.
        """
        winner = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
        in_transition = (winner == 1).astype('float32')
        in_foreground = (winner == 0).astype('float32')
        return focus_sigmoid * in_transition + in_foreground

    def init_weight(self):
        """Load ``self.pretrained`` into the model when it is set."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
class Refiner(nn.Layer):
    """Refine the coarse output to full resolution.

    The coarse tensors are resized to half of the source resolution and
    convolved together with the half-resolution image; the result is resized
    to full resolution and convolved together with the full-resolution image.

    Args:
        kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.

    Raises:
        ValueError: If ``kernel_size`` is neither 1 nor 3.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size not in [1, 3]:
            raise ValueError("kernel_size must be in [1, 3]")

        self.kernel_size = kernel_size

        channels = [32, 24, 16, 12, 1]
        # conv1 sees concat([hid, pha, tri, image]):
        # channels[0] hidden + 4 (pha and tri together) + 3 image channels.
        self.conv1 = layers.ConvBNReLU(
            channels[0] + 4 + 3,
            channels[1],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            channels[1], channels[2], kernel_size, padding=0, bias_attr=False)
        # conv3 sees concat([features, image]): channels[2] + 3 image channels.
        self.conv3 = layers.ConvBNReLU(
            channels[2] + 3,
            channels[3],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv4 = nn.Conv2D(
            channels[3], channels[4], kernel_size, padding=0, bias_attr=True)

    def forward(self, src, pha, err, hid, tri):
        """Produce the full-resolution alpha.

        Args:
            src: (B, 3, H, W) full resolution source image.
            pha: (B, 1, Hc, Wc) coarse alpha prediction.
            err: (B, 1, Hc, Wc) coarse error prediction. Accepted for
                interface compatibility; it is not used here.
            hid: (B, 32, Hc, Wc) coarse hidden encoding.
            tri: (B, 3, Hc, Wc) trimap prediction. Three channels are what
                conv1's input width (32 + 4 + 3) requires together with the
                shapes of ``hid``, ``pha`` and ``src`` above.

        Returns:
            The refined alpha, with the spatial size of ``src`` and 1 channel
            (the output width of conv4). Values are not clamped here.
        """
        h_full, w_full = paddle.shape(src)[2:]
        half_size = paddle.concat((h_full // 2, w_full // 2))

        # Bring the coarse inputs and the image to half resolution.
        x = paddle.concat([hid, pha, tri], axis=1)
        x = F.interpolate(
            x, half_size, mode='bilinear', align_corners=False)
        y = F.interpolate(
            src, half_size, mode='bilinear', align_corners=False)

        if self.kernel_size == 3:
            # The convolutions use padding=0, so each 3x3 conv trims one
            # pixel per side; pad up front to compensate.
            x = F.pad(x, [3, 3, 3, 3])
            y = F.pad(y, [3, 3, 3, 3])

        x = self.conv1(paddle.concat([x, y], axis=1))
        x = self.conv2(x)

        if self.kernel_size == 3:
            # conv3 and conv4 trim two pixels per side in total, so work on
            # a (H + 4, W + 4) canvas to end up at (H, W).
            x = F.interpolate(x, paddle.concat((h_full + 4, w_full + 4)))
            y = F.pad(src, [2, 2, 2, 2])
        else:
            x = F.interpolate(
                x, paddle.concat((h_full, w_full)), mode='nearest')
            y = src

        x = self.conv3(paddle.concat([x, y], axis=1))
        return self.conv4(x)