Spaces:
Running
Running
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import time

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD
def conv_up_psp(in_channels, out_channels, up_sample):
    """Build a 3x3 ConvBNReLU block followed by bilinear upsampling.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the convolution block.
        up_sample: Scale factor handed to ``nn.Upsample``.

    Returns:
        An ``nn.Sequential`` that applies the convolution block and then the
        upsampler.
    """
    # The layers are created in the same order as they are applied.
    projection = layers.ConvBNReLU(in_channels, out_channels, 3, padding=1)
    upsampler = nn.Upsample(
        scale_factor=up_sample, mode='bilinear', align_corners=False)
    return nn.Sequential(projection, upsampler)
class HumanMatting(nn.Layer):
    """Matting network with a glance decoder, a focus decoder and a refiner.

    The input image is downsampled by ``backbone_scale`` and passed through
    ``backbone``. Two decoders consume the backbone features:

    * the *glance* decoder predicts a 3-class map
      (foreground / transition / background);
    * the *focus* decoder predicts an alpha value, an error map and a
      32-channel hidden encoding.

    Both predictions are fused into a coarse alpha, which is optionally
    brought back to the input resolution by ``Refiner``.

    Args:
        backbone: Feature extractor. It must expose ``feat_channels`` and,
            when called, return a sequence of feature maps; the last five
            entries (indexes ``-1`` .. ``-5``) are used.
        pretrained: If not ``None``, passed to ``utils.load_entire_model`` to
            load weights for the whole model. Default: None.
        backbone_scale: Scale factor applied to the input before it is fed to
            the backbone. Must not exceed 0.5 when ``if_refine`` is True; it
            is forced to 1 when ``if_refine`` is False. Default: 0.25.
        refine_kernel_size: Kernel size forwarded to ``Refiner``. Default: 3.
        if_refine: Whether to build and run the refiner. Default: True.

    Raises:
        ValueError: If ``if_refine`` is True and ``backbone_scale`` > 0.5.
    """

    def __init__(self,
                 backbone,
                 pretrained=None,
                 backbone_scale=0.25,
                 refine_kernel_size=3,
                 if_refine=True):
        super().__init__()
        if if_refine:
            if backbone_scale > 0.5:
                raise ValueError(
                    'Backbone_scale should not be greater than 1/2, but it is {}'
                    .format(backbone_scale))
        else:
            # Without a refiner the coarse branch works at full resolution.
            backbone_scale = 1

        self.backbone = backbone
        self.backbone_scale = backbone_scale
        self.pretrained = pretrained
        self.if_refine = if_refine
        if if_refine:
            self.refiner = Refiner(kernel_size=refine_kernel_size)
        # Filled lazily by ``loss`` on first use.
        self.loss_func_dict = None

        self.backbone_channels = backbone.feat_channels

        ######################
        ### Decoder part - Glance
        ######################
        self.psp_module = layers.PPModule(
            self.backbone_channels[-1],
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)
        # Each psp branch upsamples the PSP output so that it matches the
        # resolution of the glance stage it is concatenated with.
        self.psp4 = conv_up_psp(512, 256, 2)
        self.psp3 = conv_up_psp(512, 128, 4)
        self.psp2 = conv_up_psp(512, 64, 8)
        self.psp1 = conv_up_psp(512, 64, 16)
        # stage 5g
        self.decoder5_g = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 4g
        self.decoder4_g = nn.Sequential(
            layers.ConvBNReLU(
                512, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 3g
        self.decoder3_g = nn.Sequential(
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 2g
        self.decoder2_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 1g
        self.decoder1_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 0g: 3 output channels, one per glance class.
        self.decoder0_g = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 3, 3, padding=1))

        ##########################
        ### Decoder part - FOCUS
        ##########################
        self.bridge_block = nn.Sequential(
            layers.ConvBNReLU(
                self.backbone_channels[-1], 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2))
        # stage 5f
        self.decoder5_f = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 4f
        self.decoder4_f = nn.Sequential(
            layers.ConvBNReLU(
                256 + self.backbone_channels[-2], 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 3f
        self.decoder3_f = nn.Sequential(
            layers.ConvBNReLU(
                128 + self.backbone_channels[-3], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 2f
        self.decoder2_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-4], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 1f
        self.decoder1_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-5], 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 0f: 1 alpha channel + 1 error channel + 32 hidden channels.
        self.decoder0_f = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 1 + 1 + 32, 3, padding=1))

        self.init_weight()

    def forward(self, data):
        """Run the network on a batch.

        Args:
            data: Dict holding the input image under ``'img'``. In training
                mode it is also handed to ``loss``, which reads ``'trimap'``
                and ``'alpha'``.

        Returns:
            In training mode, a tuple ``(logit_dict, loss_dict)``. Otherwise
            the predicted alpha: the refined one when ``if_refine`` is True,
            the coarse fused one when it is False.

        Raises:
            ValueError: If the refiner is enabled and the input height or
                width is not divisible by 4.
        """
        src = data['img']
        src_h, src_w = paddle.shape(src)[2:]
        if self.if_refine:
            # The check is skipped when the sizes are not tensors, which is
            # the case when exporting the model.
            if isinstance(src_h, paddle.Tensor):
                if (src_h % 4 != 0) or (src_w % 4 != 0):
                    raise ValueError(
                        'The input image must have width and height that are divisible by 4'
                    )

        # Downsample src for backbone
        src_sm = F.interpolate(
            src,
            scale_factor=self.backbone_scale,
            mode='bilinear',
            align_corners=False)

        # Base
        fea_list = self.backbone(src_sm)

        ##########################
        ### Decoder part - GLANCE
        ##########################
        # Channel counts below follow from the layer definitions. Spatial
        # sizes are written relative to src_sm and assume that fea_list[-1]
        # is at 1/32 of it, as the original notes state - TODO confirm
        # against the backbone in use.
        # psp: N, 512, H/32, W/32
        psp = self.psp_module(fea_list[-1])
        # d5_g: N, 256, H/16, W/16
        d5_g = self.decoder5_g(paddle.concat((psp, fea_list[-1]), 1))
        # d4_g: N, 128, H/8, W/8
        d4_g = self.decoder4_g(paddle.concat((self.psp4(psp), d5_g), 1))
        # d3_g: N, 64, H/4, W/4
        d3_g = self.decoder3_g(paddle.concat((self.psp3(psp), d4_g), 1))
        # d2_g: N, 64, H/2, W/2
        d2_g = self.decoder2_g(paddle.concat((self.psp2(psp), d3_g), 1))
        # d1_g: N, 64, H, W
        d1_g = self.decoder1_g(paddle.concat((self.psp1(psp), d2_g), 1))
        # d0_g: N, 3, H, W
        d0_g = self.decoder0_g(d1_g)
        # The 1st channel is foreground. The 2nd is transition region.
        # The 3rd is background. Despite the variable name, the class
        # probabilities come from a softmax over the channel axis.
        glance_sigmoid = F.softmax(d0_g, axis=1)

        ##########################
        ### Decoder part - FOCUS
        ##########################
        # bb: N, 512, H/32, W/32
        bb = self.bridge_block(fea_list[-1])
        # d5_f: N, 256, H/16, W/16
        d5_f = self.decoder5_f(paddle.concat((bb, fea_list[-1]), 1))
        # d4_f: N, 128, H/8, W/8
        d4_f = self.decoder4_f(paddle.concat((d5_f, fea_list[-2]), 1))
        # d3_f: N, 64, H/4, W/4
        d3_f = self.decoder3_f(paddle.concat((d4_f, fea_list[-3]), 1))
        # d2_f: N, 64, H/2, W/2
        d2_f = self.decoder2_f(paddle.concat((d3_f, fea_list[-4]), 1))
        # d1_f: N, 64, H, W
        d1_f = self.decoder1_f(paddle.concat((d2_f, fea_list[-5]), 1))
        # d0_f: N, 1 + 1 + 32, H, W (alpha, error, hidden encoding)
        d0_f = self.decoder0_f(d1_f)

        # Channel 0 is the focus alpha, channel 1 the error map and the
        # remaining 32 channels the hidden encoding used by the refiner.
        focus_sigmoid = F.sigmoid(d0_f[:, 0:1, :, :])
        pha_sm = self.fusion(glance_sigmoid, focus_sigmoid)
        err_sm = d0_f[:, 1:2, :, :]
        err_sm = paddle.clip(err_sm, 0., 1.)
        hid_sm = F.relu(d0_f[:, 2:, :, :])

        # Refiner
        if self.if_refine:
            pha = self.refiner(
                src=src, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_sigmoid)
            # Clamp outputs
            pha = paddle.clip(pha, 0., 1.)

        if self.training:
            logit_dict = {
                'glance': glance_sigmoid,
                'focus': focus_sigmoid,
                'fusion': pha_sm,
                'error': err_sm
            }
            if self.if_refine:
                logit_dict['refine'] = pha
            loss_dict = self.loss(logit_dict, data)
            return logit_dict, loss_dict
        else:
            return pha if self.if_refine else pha_sm

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the training losses.

        Args:
            logit_dict: Dict with the predictions ``'glance'``, ``'focus'``,
                ``'fusion'``, ``'error'`` and, when the refiner is enabled,
                ``'refine'``.
            label_dict: Dict with the ground truth ``'trimap'`` and
                ``'alpha'``. In the trimap, 128 marks the transition region
                and 0 the background; every other value is treated as
                foreground.
            loss_func_dict: Optional dict mapping ``'glance'``, ``'focus'``,
                ``'cm'``, ``'err'`` and ``'refine'`` to lists of loss
                callables; only the first entry of each list is used. When
                given it replaces the stored dict. When omitted, a default
                set is created on first use and reused afterwards.

        Returns:
            Dict with the individual losses ``'glance'``, ``'focus'``,
            ``'cm'``, ``'err'``, optionally ``'refine'``, and their weighted
            sum under ``'all'``.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['glance'].append(nn.NLLLoss())
                self.loss_func_dict['focus'].append(MRSD())
                self.loss_func_dict['cm'].append(MRSD())
                self.loss_func_dict['err'].append(paddleseg.models.MSELoss())
                self.loss_func_dict['refine'].append(paddleseg.models.L1Loss())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}

        # glance loss computation
        # get glance label
        glance_label = F.interpolate(
            label_dict['trimap'],
            logit_dict['glance'].shape[2:],
            mode='nearest',
            align_corners=False)
        glance_label_trans = (glance_label == 128).astype('int64')
        glance_label_bg = (glance_label == 0).astype('int64')
        # Class ids: 0 = foreground, 1 = transition, 2 = background.
        glance_label = glance_label_trans + glance_label_bg * 2
        # NLLLoss expects log-probabilities; the epsilon keeps log() finite.
        loss_glance = self.loss_func_dict['glance'][0](
            paddle.log(logit_dict['glance'] + 1e-6), glance_label.squeeze(1))
        loss['glance'] = loss_glance

        # focus loss computation, masked by the transition region
        focus_label = F.interpolate(
            label_dict['alpha'],
            logit_dict['focus'].shape[2:],
            mode='bilinear',
            align_corners=False)
        loss_focus = self.loss_func_dict['focus'][0](
            logit_dict['focus'], focus_label, glance_label_trans)
        loss['focus'] = loss_focus

        # collaborative matting loss
        loss_cm_func = self.loss_func_dict['cm']
        # fusion_sigmoid loss
        loss_cm = loss_cm_func[0](logit_dict['fusion'], focus_label)
        loss['cm'] = loss_cm

        # error loss: the error map should match |fused alpha - ground truth|
        err = F.interpolate(
            logit_dict['error'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False)
        err_label = (F.interpolate(
            logit_dict['fusion'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False) - label_dict['alpha']).abs()
        loss_err = self.loss_func_dict['err'][0](err, err_label)
        loss['err'] = loss_err

        loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm + loss_err

        # refine loss
        if self.if_refine:
            loss_refine = self.loss_func_dict['refine'][0](logit_dict['refine'],
                                                           label_dict['alpha'])
            loss['refine'] = loss_refine
            loss_all = loss_all + loss_refine

        loss['all'] = loss_all
        return loss

    def fusion(self, glance_sigmoid, focus_sigmoid):
        """Fuse the glance class map and the focus alpha into one alpha map.

        Args:
            glance_sigmoid: [N, 3, H, W] class scores; channel 0 is
                foreground, 1 is transition and 2 is background.
            focus_sigmoid: Focus alpha prediction, multiplied element-wise
                with the transition mask.

        Returns:
            A map that is 1 where foreground wins, 0 where background wins
            and equal to ``focus_sigmoid`` where the transition class wins.
        """
        index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        fusion_sigmoid = focus_sigmoid * transition_mask + fg
        return fusion_sigmoid

    def init_weight(self):
        """Load weights for the whole model when ``pretrained`` is set."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
class Refiner(nn.Layer):
    """Refine the coarse output to full resolution.

    The coarse hidden encoding, alpha and trimap prediction are resized to
    half of the source resolution, combined with the downsampled source
    image, then resized to full resolution and combined with the source
    image again to predict the final alpha.

    Args:
        kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.

    Raises:
        ValueError: If ``kernel_size`` is neither 1 nor 3.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size not in [1, 3]:
            raise ValueError("kernel_size must be in [1, 3]")

        self.kernel_size = kernel_size

        channels = [32, 24, 16, 12, 1]
        # conv1 input: 32 hidden + 1 alpha + 3 trimap channels (the "+ 4"),
        # plus the 3 channels of the half-resolution source image.
        self.conv1 = layers.ConvBNReLU(
            channels[0] + 4 + 3,
            channels[1],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            channels[1], channels[2], kernel_size, padding=0, bias_attr=False)
        # conv3 input: conv2 features plus the 3-channel full-size source.
        self.conv3 = layers.ConvBNReLU(
            channels[2] + 3,
            channels[3],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv4 = nn.Conv2D(
            channels[3], channels[4], kernel_size, padding=0, bias_attr=True)

    def forward(self, src, pha, err, hid, tri):
        """Predict the full-resolution alpha.

        Args:
            src: (B, 3, H, W) full resolution source image.
            pha: (B, 1, Hc, Wc) coarse alpha prediction.
            err: (B, 1, Hc, Wc) coarse error prediction. It is accepted for
                interface compatibility but not used here.
            hid: (B, 32, Hc, Wc) coarse hidden encoding.
            tri: (B, 3, Hc, Wc) trimap prediction.

        Returns:
            (B, 1, H, W) refined alpha. The value is not clamped here.
        """
        h_full, w_full = paddle.shape(src)[2:]
        h_half, w_half = h_full // 2, w_full // 2
        # Both half-resolution resizes share the same target size.
        half_size = paddle.concat((h_half, w_half))

        x = paddle.concat([hid, pha, tri], axis=1)
        x = F.interpolate(
            x, half_size, mode='bilinear', align_corners=False)
        y = F.interpolate(
            src, half_size, mode='bilinear', align_corners=False)

        if self.kernel_size == 3:
            # The convolutions use no padding and each 3x3 one trims 2
            # pixels per axis; pad by 3 per side so that conv1 + conv2
            # leave a 1-pixel border around the half-resolution map.
            x = F.pad(x, [3, 3, 3, 3])
            y = F.pad(y, [3, 3, 3, 3])

        x = self.conv1(paddle.concat([x, y], axis=1))
        x = self.conv2(x)

        if self.kernel_size == 3:
            # conv3 + conv4 trim 4 pixels per axis, so resize to full size
            # + 4 and pad the source by 2 per side to match.
            x = F.interpolate(
                x, paddle.concat((h_full + 4, w_full + 4)), mode='nearest')
            y = F.pad(src, [2, 2, 2, 2])
        else:
            x = F.interpolate(
                x, paddle.concat((h_full, w_full)), mode='nearest')
            y = src

        x = self.conv3(paddle.concat([x, y], axis=1))
        x = self.conv4(x)

        pha = x
        return pha