# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD, GradientLoss
from ppmatting.models.backbone import resnet_vd


@manager.MODELS.add_component
class PPMatting(nn.Layer):
"""
The PPMattinh implementation based on PaddlePaddle.
The original article refers to
Guowei Chen, et, al. "PP-Matting: High-Accuracy Natural Image Matting"
(https://arxiv.org/pdf/2204.09433.pdf).
Args:
backbone: backbone model.
pretrained(str, optional): The path of pretrianed model. Defautl: None.
"""
def __init__(self, backbone, pretrained=None):
super().__init__()
self.backbone = backbone
self.pretrained = pretrained
self.loss_func_dict = self.get_loss_func_dict()
self.backbone_channels = backbone.feat_channels
self.scb = SCB(self.backbone_channels[-1])
self.hrdb = HRDB(
self.backbone_channels[0] + self.backbone_channels[1],
scb_channels=self.scb.out_channels,
gf_index=[0, 2, 4])
self.init_weight()

    def forward(self, inputs):
x = inputs['img']
input_shape = paddle.shape(x)
fea_list = self.backbone(x)
scb_logits = self.scb(fea_list[-1])
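        # The final SCB logit is a 3-channel trimap prediction (fg / transition / bg).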
semantic_map = F.softmax(scb_logits[-1], axis=1)
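        # Upsample the two shallowest backbone features to the input resolution for the detail branch.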
fea0 = F.interpolate(
fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
fea1 = F.interpolate(
fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
hrdb_input = paddle.concat([fea0, fea1], 1)
hrdb_logit = self.hrdb(hrdb_input, scb_logits)
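        # HRDB predicts a single-channel detail map; sigmoid maps it to alpha values in [0, 1].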
detail_map = F.sigmoid(hrdb_logit)
fusion = self.fusion(semantic_map, detail_map)
if self.training:
logit_dict = {
'semantic': semantic_map,
'detail': detail_map,
'fusion': fusion
}
loss_dict = self.loss(logit_dict, inputs)
return logit_dict, loss_dict
else:
return fusion

    def get_loss_func_dict(self):
loss_func_dict = defaultdict(list)
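        # semantic: 3-class NLL on the trimap; detail: MRSD + gradient in the transition
        # region; fusion: MRSD on alpha, MRSD on the composition, and a gradient term.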
loss_func_dict['semantic'].append(nn.NLLLoss())
loss_func_dict['detail'].append(MRSD())
loss_func_dict['detail'].append(GradientLoss())
loss_func_dict['fusion'].append(MRSD())
loss_func_dict['fusion'].append(MRSD())
loss_func_dict['fusion'].append(GradientLoss())
return loss_func_dict

    def loss(self, logit_dict, label_dict):
loss = {}
# semantic loss computation
# get semantic label
semantic_label = label_dict['trimap']
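        # Map trimap values to class indices: 255 (fg) -> 0, 128 (transition) -> 1, 0 (bg) -> 2.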
semantic_label_trans = (semantic_label == 128).astype('int64')
semantic_label_bg = (semantic_label == 0).astype('int64')
semantic_label = semantic_label_trans + semantic_label_bg * 2
loss_semantic = self.loss_func_dict['semantic'][0](
paddle.log(logit_dict['semantic'] + 1e-6),
semantic_label.squeeze(1))
loss['semantic'] = loss_semantic
# detail loss computation
transparent = label_dict['trimap'] == 128
detail_alpha_loss = self.loss_func_dict['detail'][0](
logit_dict['detail'], label_dict['alpha'], transparent)
# gradient loss
detail_gradient_loss = self.loss_func_dict['detail'][1](
logit_dict['detail'], label_dict['alpha'], transparent)
loss_detail = detail_alpha_loss + detail_gradient_loss
loss['detail'] = loss_detail
loss['detail_alpha'] = detail_alpha_loss
loss['detail_gradient'] = detail_gradient_loss
# fusion loss
loss_fusion_func = self.loss_func_dict['fusion']
        # alpha loss on the fused prediction
fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
label_dict['alpha'])
        # composition loss: re-composite with the GT foreground/background and compare
comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
1 - logit_dict['fusion']) * label_dict['bg']
comp_gt = label_dict['alpha'] * label_dict['fg'] + (
1 - label_dict['alpha']) * label_dict['bg']
fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)
        # gradient loss
fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
label_dict['alpha'])
        # total fusion loss
loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
loss['fusion'] = loss_fusion
loss['fusion_alpha'] = fusion_alpha_loss
loss['fusion_composition'] = fusion_composition_loss
loss['fusion_gradient'] = fusion_grad_loss
        loss['all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion
return loss

    def fusion(self, semantic_map, detail_map):
        # semantic_map: [N, 3, H, W]
        # In the argmax index, 0 is foreground, 1 is transition, and 2 is background.
        # After fusion, foreground is 1, background is 0, and the transition region lies in [0, 1].
index = paddle.argmax(semantic_map, axis=1, keepdim=True)
transition_mask = (index == 1).astype('float32')
fg = (index == 0).astype('float32')
alpha = detail_map * transition_mask + fg
return alpha

    def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)


class SCB(nn.Layer):
def __init__(self, in_channels):
super().__init__()
self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
self.mid_channels = [512, 256, 128, 128, 64, 64]
self.out_channels = [256, 128, 64, 64, 64, 3]
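        # Stage 0 consumes the backbone feature concatenated with the 512-channel PSP output;
        # stages 1-4 concatenate the upsampled PSP projections with the previous stage's output.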
self.psp_module = layers.PPModule(
in_channels,
512,
bin_sizes=(1, 3, 5),
dim_reduction=False,
align_corners=False)
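        # Heads that project the 512-channel PSP output and upsample it for fusion at stages 1-4.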
psp_upsamples = [2, 4, 8, 16]
self.psps = nn.LayerList([
self.conv_up_psp(512, self.out_channels[i], psp_upsamples[i])
for i in range(4)
])
scb_list = [
self._make_stage(
self.in_channels[i],
self.mid_channels[i],
self.out_channels[i],
padding=int(i == 0) + 1,
dilation=int(i == 0) + 1)
for i in range(len(self.in_channels) - 1)
]
scb_list += [
nn.Sequential(
layers.ConvBNReLU(
self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
layers.ConvBNReLU(
self.mid_channels[-1], self.mid_channels[-1], 3, padding=1),
nn.Conv2D(
self.mid_channels[-1], self.out_channels[-1], 3, padding=1))
]
self.scb_stages = nn.LayerList(scb_list)

    def forward(self, x):
psp_x = self.psp_module(x)
psps = [psp(psp_x) for psp in self.psps]
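        # Stage 0 fuses the raw PSP output, stages 1-4 fuse its upsampled projections,
        # and the final stage runs on the previous output alone.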
scb_logits = []
for i, scb_stage in enumerate(self.scb_stages):
if i == 0:
x = scb_stage(paddle.concat((psp_x, x), 1))
elif i <= len(psps):
x = scb_stage(paddle.concat((psps[i - 1], x), 1))
else:
x = scb_stage(x)
scb_logits.append(x)
return scb_logits

    def conv_up_psp(self, in_channels, out_channels, up_sample):
return nn.Sequential(
layers.ConvBNReLU(
in_channels, out_channels, 3, padding=1),
nn.Upsample(
scale_factor=up_sample, mode='bilinear', align_corners=False))

    def _make_stage(self,
                    in_channels,
                    mid_channels,
                    out_channels,
                    padding=1,
                    dilation=1):
        layer_list = [
            layers.ConvBNReLU(in_channels, mid_channels, 3, padding=1),
            layers.ConvBNReLU(
                mid_channels,
                mid_channels,
                3,
                padding=padding,
                dilation=dilation),
            layers.ConvBNReLU(
                mid_channels,
                out_channels,
                3,
                padding=padding,
                dilation=dilation),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False)
        ]
        return nn.Sequential(*layer_list)


class HRDB(nn.Layer):
    """
    The High-Resolution Detail Branch.

    Args:
        in_channels (int): The number of input channels.
        scb_channels (list|tuple): The channels of the SCB logits.
        gf_index (list|tuple, optional): Which SCB logits are selected as guidance flow. Default: (0, 2, 4).
    """

def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
super().__init__()
self.gf_index = gf_index
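        # 1x1 convs that squeeze the selected SCB logits into single-channel guidance-flow maps.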
self.gf_list = nn.LayerList(
[nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index])
channels = [64, 32, 16, 8]
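        # Channel widths along the high-resolution stream, reduced stage by stage.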
self.res_list = [
resnet_vd.BasicBlock(
in_channels, channels[0], stride=1, shortcut=False)
]
self.res_list += [
resnet_vd.BasicBlock(
i, i, stride=1) for i in channels[1:-1]
]
self.res_list = nn.LayerList(self.res_list)
self.convs = nn.LayerList([
nn.Conv2D(
channels[i], channels[i + 1], kernel_size=1)
for i in range(len(channels) - 1)
])
self.gates = nn.LayerList(
[GatedSpatailConv2d(i, i) for i in channels[1:]])
self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)

    def forward(self, x, scb_logits):
for i in range(len(self.res_list)):
x = self.res_list[i](x)
x = self.convs[i](x)
gf = self.gf_list[i](scb_logits[self.gf_index[i]])
gf = F.interpolate(
gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False)
x = self.gates[i](x, gf)
return self.detail_conv(x)


class GatedSpatailConv2d(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
groups=1,
bias_attr=False):
super().__init__()
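        # Gate: concatenate features with the guidance map and squeeze to a
        # single-channel sigmoid attention map.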
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(in_channels + 1),
            nn.Conv2D(in_channels + 1, in_channels + 1, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(in_channels + 1, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias_attr=bias_attr)

    def forward(self, input_features, gating_features):
cat = paddle.concat([input_features, gating_features], axis=1)
alphas = self._gate_conv(cat)
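        # Residual gating: scale features by (1 + alpha) so the gate modulates rather than suppresses them.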
x = input_features * (alphas + 1)
x = self.conv(x)
return x